From f6096b1e52aa949fb386b00f4e2853709f655df0 Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Wed, 29 Nov 2023 09:44:46 -0800 Subject: [PATCH 01/10] initial commit --- .gitignore | 163 + .vscode/launch.json | 77 + README.md | 130 + config/templates/AFPredStack.json | 260 + config/templates/CTTPredStack.json | 230 + config/templates/SceneFormerPredStack.json | 358 + .../SceneFormerPredStack_nuplan.json | 358 + diffstack/__init__.py | 0 diffstack/configs/__init__.py | 2 + diffstack/configs/algo_config.py | 247 + diffstack/configs/base.py | 139 + diffstack/configs/config.py | 190 + diffstack/configs/eval_config.py | 126 + diffstack/configs/registry.py | 42 + diffstack/configs/trajdata_config.py | 99 + diffstack/data/__init__.py | 0 diffstack/data/agent_batch_extras.py | 183 + diffstack/data/scene_batch_extras.py | 188 + diffstack/data/trajdata_datamodules.py | 229 + diffstack/data/trajdata_lanes.py | 165 + diffstack/dynamics/__init__.py | 21 + diffstack/dynamics/base.py | 87 + diffstack/dynamics/bicycle.py | 153 + diffstack/dynamics/double_integrator.py | 202 + diffstack/dynamics/single_integrator.py | 52 + diffstack/dynamics/unicycle.py | 834 ++ diffstack/models/CTT.py | 2758 +++++++ diffstack/models/RPE_simple.py | 441 + diffstack/models/Transformer.py | 861 ++ diffstack/models/TypeTransformer.py | 791 ++ diffstack/models/__init__.py | 0 diffstack/models/agentformer.py | 3324 ++++++++ diffstack/models/agentformer_lib.py | 1044 +++ diffstack/models/base_models.py | 1688 ++++ diffstack/models/cnn_roi_encoder.py | 558 ++ diffstack/models/layers.py | 782 ++ diffstack/models/learned_metrics.py | 85 + diffstack/models/unet.py | 512 ++ diffstack/models/vaes.py | 1348 ++++ diffstack/modules/__init__.py | 0 diffstack/modules/metric_models/metrics.py | 321 + diffstack/modules/module.py | 438 + diffstack/modules/predictors/CTT.py | 1673 ++++ diffstack/modules/predictors/__init__.py | 0 diffstack/modules/predictors/factory.py | 69 + .../modules/predictors/kinematic_predictor.py | 163 + .../modules/predictors/tbsim_predictors.py | 477 ++ .../trajectron_utils/environment.py | 79 + .../trajectron_utils/environment/__init__.py | 8 + .../environment/data_structures.py | 276 + .../environment/data_utils.py | 45 + .../environment/environment.py | 79 + .../trajectron_utils/environment/map.py | 229 + .../trajectron_utils/environment/node.py | 265 + .../trajectron_utils/environment/node_type.py | 36 + .../trajectron_utils/environment/scene.py | 228 + .../environment/scene_graph.py | 493 ++ .../trajectron_utils/model/__init__.py | 1 + .../model/components/__init__.py | 4 + .../model/components/additive_attention.py | 67 + .../model/components/discrete_latent.py | 109 + .../model/components/gmm2d.py | 181 + .../model/components/graph_attention.py | 58 + .../model/components/map_encoder.py | 28 + .../model/dataset/__init__.py | 2 + .../trajectron_utils/model/dataset/dataset.py | 314 + .../model/dataset/homography_warper.py | 471 ++ .../model/dataset/preprocessing.py | 757 ++ .../model/dynamics/__init__.py | 4 + .../model/dynamics/dynamic.py | 30 + .../trajectron_utils/model/dynamics/linear.py | 12 + .../model/dynamics/single_integrator.py | 64 + .../model/dynamics/unicycle.py | 274 + .../trajectron_utils/model/mgcvae.py | 1169 +++ .../trajectron_utils/model/model_registrar.py | 70 + .../trajectron_utils/model/model_utils.py | 125 + .../trajectron_utils/model/online/__init__.py | 2 + .../model/online/online_mgcvae.py | 419 + .../model/online/online_trajectron.py | 310 + .../predictors/trajectron_utils/node.py | 265 + .../predictors/trajectron_utils/node_type.py | 36 + .../trajectron_utils/trajectron/__init__.py | 0 .../trajectron_utils/trajectron/trajectron.py | 6 + .../scripts/generate_config_templates.py | 19 + diffstack/scripts/train_pl.py | 582 ++ diffstack/stacks/base.py | 347 + diffstack/stacks/pred_stack.py | 44 + diffstack/stacks/stack_factory.py | 82 + diffstack/utils/__init__.py | 0 diffstack/utils/algo_utils.py | 283 + diffstack/utils/batch_utils.py | 299 + diffstack/utils/bezier_utils.py | 97 + diffstack/utils/config.py | 59 + diffstack/utils/config_utils.py | 132 + diffstack/utils/diffusion_utils.py | 256 + diffstack/utils/dist_utils.py | 680 ++ diffstack/utils/env_utils.py | 478 ++ diffstack/utils/experiment_utils.py | 393 + diffstack/utils/fp16_util.py | 76 + diffstack/utils/geometry_utils.py | 567 ++ diffstack/utils/homotopy.py | 113 + diffstack/utils/kalman_filter.py | 120 + diffstack/utils/l5_utils.py | 726 ++ diffstack/utils/lane_utils.py | 553 ++ diffstack/utils/log_utils.py | 84 + diffstack/utils/loss_utils.py | 858 ++ diffstack/utils/math_utils.py | 59 + diffstack/utils/metrics.py | 848 ++ diffstack/utils/model_registrar.py | 88 + diffstack/utils/model_utils.py | 772 ++ diffstack/utils/planning_utils.py | 722 ++ diffstack/utils/pred_utils.py | 105 + diffstack/utils/rollout_logger.py | 226 + diffstack/utils/sys_utils.py | 12 + diffstack/utils/tensor_utils.py | 1189 +++ diffstack/utils/timer.py | 65 + diffstack/utils/torch_utils.py | 316 + diffstack/utils/tpp_utils.py | 994 +++ diffstack/utils/train_utils.py | 151 + diffstack/utils/trajdata_utils.py | 849 ++ diffstack/utils/tree.py | 102 + diffstack/utils/utils.py | 846 ++ diffstack/utils/vis_utils.py | 394 + diffstack/utils/visualization.py | 915 +++ patches/trajdata_vectorize.patch | 7097 +++++++++++++++++ requirements.txt | 36 + setup.py | 29 + 127 files changed, 51047 insertions(+) create mode 100644 .gitignore create mode 100644 .vscode/launch.json create mode 100644 README.md create mode 100644 config/templates/AFPredStack.json create mode 100644 config/templates/CTTPredStack.json create mode 100644 config/templates/SceneFormerPredStack.json create mode 100644 config/templates/SceneFormerPredStack_nuplan.json create mode 100644 diffstack/__init__.py create mode 100644 diffstack/configs/__init__.py create mode 100644 diffstack/configs/algo_config.py create mode 100644 diffstack/configs/base.py create mode 100644 diffstack/configs/config.py create mode 100644 diffstack/configs/eval_config.py create mode 100644 diffstack/configs/registry.py create mode 100644 diffstack/configs/trajdata_config.py create mode 100644 diffstack/data/__init__.py create mode 100644 diffstack/data/agent_batch_extras.py create mode 100644 diffstack/data/scene_batch_extras.py create mode 100644 diffstack/data/trajdata_datamodules.py create mode 100644 diffstack/data/trajdata_lanes.py create mode 100644 diffstack/dynamics/__init__.py create mode 100644 diffstack/dynamics/base.py create mode 100644 diffstack/dynamics/bicycle.py create mode 100644 diffstack/dynamics/double_integrator.py create mode 100644 diffstack/dynamics/single_integrator.py create mode 100644 diffstack/dynamics/unicycle.py create mode 100644 diffstack/models/CTT.py create mode 100644 diffstack/models/RPE_simple.py create mode 100644 diffstack/models/Transformer.py create mode 100644 diffstack/models/TypeTransformer.py create mode 100644 diffstack/models/__init__.py create mode 100644 diffstack/models/agentformer.py create mode 100644 diffstack/models/agentformer_lib.py create mode 100644 diffstack/models/base_models.py create mode 100644 diffstack/models/cnn_roi_encoder.py create mode 100644 diffstack/models/layers.py create mode 100644 diffstack/models/learned_metrics.py create mode 100644 diffstack/models/unet.py create mode 100644 diffstack/models/vaes.py create mode 100644 diffstack/modules/__init__.py create mode 100644 diffstack/modules/metric_models/metrics.py create mode 100644 diffstack/modules/module.py create mode 100644 diffstack/modules/predictors/CTT.py create mode 100644 diffstack/modules/predictors/__init__.py create mode 100644 diffstack/modules/predictors/factory.py create mode 100644 diffstack/modules/predictors/kinematic_predictor.py create mode 100644 diffstack/modules/predictors/tbsim_predictors.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/data_structures.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/data_utils.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/environment.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/map.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/node.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/node_type.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/scene.py create mode 100644 diffstack/modules/predictors/trajectron_utils/environment/scene_graph.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/additive_attention.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/discrete_latent.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/graph_attention.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/components/map_encoder.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dataset/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dataset/homography_warper.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dynamics/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dynamics/dynamic.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dynamics/linear.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dynamics/single_integrator.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/mgcvae.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/model_registrar.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/model_utils.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/online/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/online/online_mgcvae.py create mode 100644 diffstack/modules/predictors/trajectron_utils/model/online/online_trajectron.py create mode 100644 diffstack/modules/predictors/trajectron_utils/node.py create mode 100644 diffstack/modules/predictors/trajectron_utils/node_type.py create mode 100644 diffstack/modules/predictors/trajectron_utils/trajectron/__init__.py create mode 100644 diffstack/modules/predictors/trajectron_utils/trajectron/trajectron.py create mode 100644 diffstack/scripts/generate_config_templates.py create mode 100644 diffstack/scripts/train_pl.py create mode 100644 diffstack/stacks/base.py create mode 100644 diffstack/stacks/pred_stack.py create mode 100644 diffstack/stacks/stack_factory.py create mode 100644 diffstack/utils/__init__.py create mode 100644 diffstack/utils/algo_utils.py create mode 100644 diffstack/utils/batch_utils.py create mode 100644 diffstack/utils/bezier_utils.py create mode 100644 diffstack/utils/config.py create mode 100644 diffstack/utils/config_utils.py create mode 100644 diffstack/utils/diffusion_utils.py create mode 100644 diffstack/utils/dist_utils.py create mode 100644 diffstack/utils/env_utils.py create mode 100644 diffstack/utils/experiment_utils.py create mode 100644 diffstack/utils/fp16_util.py create mode 100644 diffstack/utils/geometry_utils.py create mode 100644 diffstack/utils/homotopy.py create mode 100644 diffstack/utils/kalman_filter.py create mode 100644 diffstack/utils/l5_utils.py create mode 100644 diffstack/utils/lane_utils.py create mode 100644 diffstack/utils/log_utils.py create mode 100644 diffstack/utils/loss_utils.py create mode 100644 diffstack/utils/math_utils.py create mode 100644 diffstack/utils/metrics.py create mode 100644 diffstack/utils/model_registrar.py create mode 100644 diffstack/utils/model_utils.py create mode 100644 diffstack/utils/planning_utils.py create mode 100644 diffstack/utils/pred_utils.py create mode 100644 diffstack/utils/rollout_logger.py create mode 100644 diffstack/utils/sys_utils.py create mode 100644 diffstack/utils/tensor_utils.py create mode 100644 diffstack/utils/timer.py create mode 100644 diffstack/utils/torch_utils.py create mode 100644 diffstack/utils/tpp_utils.py create mode 100644 diffstack/utils/train_utils.py create mode 100644 diffstack/utils/trajdata_utils.py create mode 100644 diffstack/utils/tree.py create mode 100644 diffstack/utils/utils.py create mode 100644 diffstack/utils/vis_utils.py create mode 100644 diffstack/utils/visualization.py create mode 100644 patches/trajdata_vectorize.patch create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cac3b3f --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +*.csv +*.gif +*.png +*.zip +wandb/ + +*.pkl +*.pickle + + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Old code dump +/old/ + +# Experiment results +/experiments/ + +/cache/ + +# Mac OSX +.DS_Store +._.DS_Store + +settings.json \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..b96296a --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,77 @@ +{ + + "version": "0.2.0", + "configurations": [ + { + "name": "PL train CTT", + "type": "python", + "request": "launch", + "program": "diffstack/scripts/train_pl.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--config_file=${workspaceFolder}/config/templates/CTTPredStack.json", + "--remove_exp_dir", + // "--debug", + "--dataset_path=", + ] + }, + { + "name": "PL train agentformer", + "type": "python", + "request": "launch", + "program": "diffstack/scripts/train_pl.py", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYTHONPATH": "${workspaceFolder}${pathSeparator}${env:PYTHONPATH}", + }, + "args": [ + "--config_file=${workspaceFolder}/config/templates/AFPredStack.json", + "--remove_exp_dir", + // "--debug", + "--dataset_path=", + ] + }, + { + "name": "PL eval CTT", + "type": "python", + "request": "launch", + "program": "diffstack/scripts/train_pl.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--config_file=", + "--ckpt_path=", + "--test_data=", + "--test_data_root=", + "--evaluate", + "--log_image_frequency=10", + "--eval_output_dir=", + "--test_batch_size=16", + // "--log_all_image", + "--dataset_path=", + ] + }, + { + "name": "PL eval agentformer", + "type": "python", + "request": "launch", + "program": "diffstack/scripts/train_pl.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--config_file=", + "--ckpt_path=", + "--test_data=", + "--test_data_root=", + "--evaluate", + "--log_image_frequency=10", + "--eval_output_dir=", + "--test_batch_size=16", + // "--log_all_image", + "--dataset_path=", + ] + } + ] +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..154510b --- /dev/null +++ b/README.md @@ -0,0 +1,130 @@ +# Differentiable Stack + +Impements Categorical Traffic Transformer in the environment of diffstack. + +Paper [pdf](link) + +## Setup + +Clone the repo with the desired branch. Use `--recurse-submodules` to also clone various submodules + +For trajdata, we need to use branch `vectorize`, there are two options: + +1. clone from NVlabs and then apply a patch + +``` +git clone --recurse-submodules --branch main git@github.com:NVlabs/trajdata.git; +cd trajdata; +git apply ../patches/trajdata_vectorize.patch +cd .. +``` + +2. clone from a forked repo of trajdata + + +``` +git clone --recurse-submodules --branch vectorize git@github.com:chenyx09/trajdata.git +``` + +Then add Pplan + +``` +git clone --recurse-submodules git@github.com:NVlabs/spline-planner.git + +``` + +You can also sync submodules later using +``` +git submodule update --remote +``` + +### Install diffstack + +We will install diffstack with a conda env. + +Create a `conda` environment for `diffstack`: + +``` +conda create -n diffstack python=3.9 +conda activate diffstack +``` + +Next install torch pytorch compatible to your CUDA setup following [Pytorch website](https://pytorch.org/get-started/locally/) + + + +Install the required python packages for diffstack + +``` +pip install -r requirements.txt +``` + +Install submodules manually (use `-e` for developer mode) +``` +pip install -e ./trajdata +pip install -e ./spline-planner +``` + +These additional steps might be necessary +``` +# need to reinstall pathos, gets replaced by multiprocessing install +pip uninstall pathos -y +pip install pathos==0.2.9 + +# Sometimes you need to reinstall matplotlib with the correct version + +``` +pip install matplotlib==3.3.4 +``` + +# On Mac sometimes we need to reinstall torch +conda install pytorch torchvision torchaudio -c pytorch + + + +# Fix opencv compatibility issue https://github.com/opencv/opencv-python/issues/591 +pip uninstall opencv-python opencv-python-headless -y +pip install "opencv-python-headless==4.2.0.34" +# pip install "opencv-python-headless==4.7.0.72" # for python 3.9 +``` + +### Data + +CTT uses [trajdata](https://github.com/NVlabs/trajdata) as the dataloader, technically, you can train with any dataset supported by trajdata. Considering the vectorized map support, we have tested CTT with WOMD, nuScenes, and nuPlan. + + +## Generating config templates + +``` +python diffstack/scripts/generate_config_templates.py +``` + +## Training and eval + +Training script: + +``` +python diffstack/scripts/train_pl.py +--config_file=/config/templates/CTTPredStack.json +--remove_exp_dir +--dataset_path= +``` + +Eval script: + +``` +python diffstack/scripts/train_pl.py +--evaluate +--config_file= +--ckpt_path= +--test_data= +--test_data_root= +--log_image_frequency=10 +--eval_output_dir= +--test_batch_size=16 +--dataset_path= +``` + +Training and eval example commands are also included in the `.vscode/launch.json` file. + + diff --git a/config/templates/AFPredStack.json b/config/templates/AFPredStack.json new file mode 100644 index 0000000..7a5a731 --- /dev/null +++ b/config/templates/AFPredStack.json @@ -0,0 +1,260 @@ +{ + "registered_name": "AFPredStack", + "train": { + "logging": { + "terminal_output_to_txt": true, + "log_tb": false, + "log_wandb": true, + "wandb_project_name": "diffstack", + "log_every_n_steps": 10, + "flush_every_n_steps": 100 + }, + "save": { + "enabled": true, + "every_n_steps": 1000, + "best_k": 10, + "save_best_rollout": false, + "save_best_validation": true + }, + "rollout": { + "save_video": true, + "enabled": false, + "every_n_steps": 5000, + "warm_start_n_steps": 1 + }, + "training": { + "batch_size": 100, + "num_steps": 200000, + "num_data_workers": 8 + }, + "validation": { + "enabled": true, + "batch_size": 32, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + }, + "parallel_strategy": "ddp", + "rebuild_cache": false, + "on_ngc": false, + "amp": false, + "auto_batch_size": false, + "max_batch_size": 1000, + "gradient_clip_val": 0.5, + "trajdata_source_train": "train", + "trajdata_source_valid": "val", + "trajdata_source_test": null, + "trajdata_source_root": "nusc_trainval", + "trajdata_val_source_root": null, + "trajdata_test_source_root": null, + "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", + "datamodule_class": "UnifiedDataModule", + "ego_only": true, + "test": { + "enabled": false, + "batch_size": 32, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + } + }, + "env": { + "name": "nusc_trainval", + "rasterizer": { + "raster_size": 224, + "pixel_size": 0.5, + "ego_center": [ + -0.75, + 0.0 + ] + }, + "data_generation_params": { + "other_agents_num": 10, + "max_agents_distance": 30, + "yaw_correction_speed": 1.0 + }, + "simulation": { + "distance_th_close": 30, + "num_simulation_steps": 50, + "start_frame_index": 0, + "vectorize_lane": "ego" + }, + "incl_neighbor_map": false, + "incl_vector_map": true, + "incl_raster_map": false, + "remove_single_successor": false, + "calc_lane_graph": true, + "max_num_lanes": 15, + "num_lane_pts": 32, + "remove_parked": false + }, + "stack": { + "predictor": { + "name": "agentformer", + "seed": 1, + "load_map": false, + "dynamic_type": "Unicycle", + "step_time": 0.1, + "history_num_frames": 10, + "future_num_frames": 20, + "traj_scale": 10, + "nz": 32, + "sample_k": 4, + "tf_model_dim": 256, + "tf_ff_dim": 512, + "tf_nhead": 8, + "tf_dropout": 0.1, + "z_tau": { + "start": 0.5, + "finish": 0.0001, + "decay": 0.5 + }, + "input_type": [ + "scene_norm", + "vel", + "heading" + ], + "fut_input_type": [ + "scene_norm", + "vel", + "heading" + ], + "dec_input_type": [ + "heading" + ], + "pred_type": "dynamic", + "sn_out_type": "norm", + "sn_out_heading": false, + "pos_concat": true, + "rand_rot_scene": false, + "use_map": true, + "pooling": "mean", + "agent_enc_shuffle": false, + "vel_heading": false, + "max_agent_len": 128, + "agent_enc_learn": false, + "use_agent_enc": false, + "motion_dim": 2, + "forecast_dim": 2, + "z_type": "gaussian", + "nlayer": 6, + "ar_detach": true, + "pred_scale": 1.0, + "pos_offset": false, + "learn_prior": true, + "discrete_rot": false, + "map_global_rot": false, + "ar_train": true, + "max_train_agent": 100, + "num_eval_samples": 5, + "UAC": false, + "loss_cfg": { + "kld": { + "min_clip": 1.0 + }, + "sample": { + "weight": 1.0, + "k": 20 + } + }, + "loss_weights": { + "prediction_loss": 1.0, + "kl_loss": 1.0, + "collision_loss": 3.0, + "EC_collision_loss": 5.0, + "diversity_loss": 0.3, + "deviation_loss": 0.1 + }, + "scene_orig_all_past": false, + "conn_dist": 100000.0, + "scene_centric": true, + "stage": 2, + "num_frames_per_stage": 10, + "ego_conditioning": true, + "perturb": { + "enabled": true, + "N_pert": 1, + "OU": { + "theta": 0.8, + "sigma": 2.0, + "scale": [ + 1.0, + 0.3 + ] + } + }, + "map_encoder": { + "model_architecture": "resnet18", + "image_shape": [ + 3, + 224, + 224 + ], + "feature_dim": 32, + "spatial_softmax": { + "enabled": false, + "kwargs": { + "num_kp": 32, + "temperature": 1.0, + "learnable_temperature": false + } + } + }, + "context_encoder": { + "nlayer": 2 + }, + "future_decoder": { + "nlayer": 2, + "out_mlp_dim": [ + 512, + 256 + ] + }, + "future_encoder": { + "nlayer": 2 + }, + "optim_params": { + "policy": { + "learning_rate": { + "initial": 0.0001, + "decay_factor": 0.1, + "epoch_schedule": [] + }, + "regularization": { + "L2": 0.0 + } + } + } + }, + "stack_type": "pred" + }, + "eval": { + "name": null, + "env": "nusc", + "dataset_path": null, + "eval_class": "", + "seed": 0, + "ckpt_root_dir": "checkpoints/", + "ckpt": { + "dir": null, + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "eval": { + "batch_size": 100, + "num_steps": null, + "num_data_workers": 8 + }, + "log_image_frequency": null, + "log_all_image": false, + "trajdata_source_root": "nusc_trainval", + "trajdata_source_eval": "val" + }, + "name": "test", + "root_dir": "predictor_agentformer_trained_models/", + "seed": 1, + "devices": { + "num_gpus": 1 + } +} \ No newline at end of file diff --git a/config/templates/CTTPredStack.json b/config/templates/CTTPredStack.json new file mode 100644 index 0000000..7d1ea83 --- /dev/null +++ b/config/templates/CTTPredStack.json @@ -0,0 +1,230 @@ +{ + "registered_name": "CTTPredStack", + "train": { + "logging": { + "terminal_output_to_txt": true, + "log_tb": false, + "log_wandb": true, + "wandb_project_name": "diffstack", + "log_every_n_steps": 10, + "flush_every_n_steps": 100 + }, + "save": { + "enabled": true, + "every_n_steps": 1000, + "best_k": 10, + "save_best_rollout": false, + "save_best_validation": true + }, + "rollout": { + "save_video": true, + "enabled": false, + "every_n_steps": 5000, + "warm_start_n_steps": 1 + }, + "training": { + "batch_size": 2, + "num_steps": 200000, + "num_data_workers": 8 + }, + "validation": { + "enabled": true, + "batch_size": 2, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + }, + "parallel_strategy": "ddp", + "rebuild_cache": false, + "on_ngc": false, + "amp": false, + "auto_batch_size": false, + "max_batch_size": 1000, + "gradient_clip_val": 0.5, + "trajdata_source_train": "train", + "trajdata_source_valid": "val", + "trajdata_source_test": null, + "trajdata_source_root": "nusc_trainval", + "trajdata_val_source_root": null, + "trajdata_test_source_root": null, + "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", + "datamodule_class": "UnifiedDataModule", + "ego_only": true, + "test": { + "enabled": false, + "batch_size": 32, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + } + }, + "env": { + "name": "nusc_trainval", + "rasterizer": { + "raster_size": 224, + "pixel_size": 0.5, + "ego_center": [ + -0.75, + 0.0 + ] + }, + "data_generation_params": { + "other_agents_num": 10, + "max_agents_distance": 30, + "yaw_correction_speed": 1.0 + }, + "simulation": { + "distance_th_close": 30, + "num_simulation_steps": 50, + "start_frame_index": 0, + "vectorize_lane": "ego" + }, + "incl_neighbor_map": false, + "incl_vector_map": true, + "incl_raster_map": false, + "remove_single_successor": false, + "calc_lane_graph": true, + "max_num_lanes": 15, + "num_lane_pts": 30, + "remove_parked": false + }, + "stack": { + "predictor": { + "name": "CTT", + "step_time": 0.25, + "history_num_frames": 4, + "future_num_frames": 12, + "n_embd": 256, + "n_head": 8, + "use_rpe_net": true, + "PE_mode": "RPE", + "enc_nblock": 3, + "dec_nblock": 3, + "encoder": { + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "edge_scale": 20, + "edge_clip": [ + -4, + 4 + ], + "mode_embed_dim": 64, + "jm_GNN_nblock": 2, + "num_joint_samples": 30, + "num_joint_factors": 6, + "null_lane_mode": true + }, + "edge_dim": { + "a2a": 14, + "a2l": 12, + "l2a": 12, + "l2l": 16 + }, + "a2l_edge_type": "attn", + "a2l_n_embd": 64, + "attn_ntype": { + "a2a": 2, + "a2l": 1, + "l2l": 2 + }, + "lane_GNN_num_layers": 4, + "homotpy_GNN_num_layers": 4, + "hist_lane_relation": "LaneRelation", + "fut_lane_relation": "SimpleLaneRelation", + "classify_a2l_4all_lanes": false, + "closed_loop": false, + "CL_Tf_mode": 6, + "CL_step": 2, + "decoder": { + "arch": "mlp", + "lstm_hidden_size": 128, + "mlp_hidden_dims": [ + 128, + 256 + ], + "traj_dim": 4, + "num_layers": 2, + "dyn": { + "vehicle": "unicycle", + "pedestrian": "DI_unicycle" + }, + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "decode_num_modes": 5, + "AR_step_size": 1, + "GNN_enabled": false, + "AR_update_mode": null, + "dec_rounds": 5 + }, + "num_lane_pts": 30, + "loss_weights": { + "marginal_lm_loss": 5.0, + "marginal_homo_loss": 5.0, + "joint_prob_loss": 5.0, + "xy_loss": 4.0, + "yaw_loss": 1.0, + "lm_consistency_loss": 5.0, + "homotopy_consistency_loss": 5.0, + "yaw_dev_loss": 30.0, + "y_dev_loss": 5.0, + "l2_reg": 0.1, + "coll_loss": 0.1, + "acce_reg_loss": 0.04, + "steering_reg_loss": 0.1, + "input_violation_loss": 20.0, + "jerk_loss": 0.05 + }, + "loss": { + "lm_margin_offset": 0.2 + }, + "weighted_consistency_loss": false, + "LR_sample_hack": true, + "scene_centric": true, + "max_joint_cardinality": 5, + "optim_params": { + "policy": { + "learning_rate": { + "initial": 0.0001, + "decay_factor": 0.1, + "epoch_schedule": [] + }, + "regularization": { + "L2": 0.0 + } + } + } + }, + "stack_type": "pred" + }, + "eval": { + "name": null, + "env": "nusc", + "dataset_path": null, + "eval_class": "", + "seed": 0, + "ckpt_root_dir": "checkpoints/", + "ckpt": { + "dir": null, + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "eval": { + "batch_size": 100, + "num_steps": null, + "num_data_workers": 8 + }, + "log_image_frequency": null, + "log_all_image": false, + "trajdata_source_root": "nusc_trainval", + "trajdata_source_eval": "val" + }, + "name": "test", + "root_dir": "predictor_CTT_trained_models/", + "seed": 1, + "devices": { + "num_gpus": 1 + } +} \ No newline at end of file diff --git a/config/templates/SceneFormerPredStack.json b/config/templates/SceneFormerPredStack.json new file mode 100644 index 0000000..eade21a --- /dev/null +++ b/config/templates/SceneFormerPredStack.json @@ -0,0 +1,358 @@ +{ + "registered_name": "SceneFormerPredStack", + "train": { + "logging": { + "terminal_output_to_txt": true, + "log_tb": false, + "log_wandb": true, + "wandb_project_name": "diffstack", + "log_every_n_steps": 10, + "flush_every_n_steps": 100 + }, + "save": { + "enabled": true, + "every_n_steps": 500, + "best_k": 10, + "save_best_rollout": false, + "save_best_validation": true + }, + "rollout": { + "save_video": true, + "enabled": false, + "every_n_steps": 5000, + "warm_start_n_steps": 1 + }, + "training": { + "batch_size": 24, + "num_steps": 1000000, + "num_data_workers": 64 + }, + "validation": { + "enabled": true, + "batch_size": 24, + "num_data_workers": 64, + "every_n_steps": 400, + "num_steps_per_epoch": 20 + }, + "parallel_strategy": "ddp", + "rebuild_cache": false, + "on_ngc": false, + "trajdata_source_train": "train", + "trajdata_source_valid": "val", + "trajdata_source_root": "nusc_trainval", + "trajdata_val_source_root": null, + "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", + "datamodule_class": "UnifiedDataModule", + "ego_only": true, + "amp": true, + "auto_batch_size": false, + "max_batch_size": 36, + "gradient_clip_val": 0.5 + }, + "env": { + "name": "nusc_trainval", + "rasterizer": { + "raster_size": 224, + "pixel_size": 0.5, + "ego_center": [ + -0.75, + 0.0 + ] + }, + "data_generation_params": { + "other_agents_num": 11, + "max_agents_distance": 40, + "yaw_correction_speed": 1.0 + }, + "simulation": { + "distance_th_close": 30, + "num_simulation_steps": 50, + "start_frame_index": 0, + "vectorize_lane": "ego" + }, + "incl_neighbor_map": false, + "incl_vector_map": true, + "incl_raster_map": false, + "calc_lane_graph": true, + "max_num_lanes": 32, + "num_lane_pts": 32, + "remove_single_successor": true + }, + "stack": { + "predictor": { + "name": "sceneformer", + "step_time": 0.25, + "history_num_frames": 6, + "future_num_frames": 12, + "n_embd": 128, + "n_head": 4, + "PE_mode": "PE", + "use_rpe_net": false, + "enc_nblock": 2, + "dec_nblock": 2, + "edge_dim": { + "a2a": 14, + "a2l": 12, + "l2a": 12, + "l2l": 16 + }, + "a2l_edge_type": "proj", + "a2l_n_embd": 64, + "attn_ntype": { + "a2a": 2, + "a2l": 1, + "l2l": 2 + }, + "lane_GNN_num_layers": 4, + "homotopy_GNN_num_layers": 4, + "closed_loop": false, + "CL_Tf_mode": 6, + "CL_step": 2, + "encoder": { + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "edge_scale": 20, + "edge_clip": [ + -4, + 4 + ], + "mode_embed_dim": 64, + "jm_GNN_nblock": 2, + "num_joint_samples": 30, + "num_joint_factors": 7, + "null_lane_mode": true + }, + "decoder": { + "arch": "mlp", + "lstm_hidden_size": 128, + "mlp_hidden_dims": [ + 128, + 256 + ], + "traj_dim": 4, + "num_layers": 2, + "dyn": { + "vehicle": "unicycle", + "pedestrian": "DI_unicycle" + }, + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "decode_num_modes": 3, + "AR_step_size": 1, + "GNN_enabled": false, + "AR_update_mode": "step", + "dec_rounds": 5 + }, + "num_lane_pts": 32, + "hist_lane_relation": "LaneRelation", + "fut_lane_relation": "SimpleLaneRelation", + "classify_a2l_4all_lanes": false, + "max_joint_cardinality": 5, + "loss_weights": { + "marginal_lm_loss": 5.0, + "marginal_homo_loss": 5.0, + "joint_prob_loss": 5.0, + "xy_loss": 2.0, + "heading_loss": 1.0, + "l2_reg": 0.0001, + "lm_consistency_loss": 5.0, + "homotopy_consistency_loss": 5.0, + "coll_loss": 2.0, + "acce_reg_loss": 0.05, + "steering_reg_loss": 0.2, + "input_violation_loss": 20.0, + "jerk_loss": 0.1 + }, + "loss": { + "lm_margin_offset": 0.2 + }, + "weighted_consistency_loss": false, + "LR_sample_hack": true, + "scene_centric": true, + "optim_params": { + "policy": { + "learning_rate": { + "initial": 0.0001, + "decay_factor": 0.05, + "epoch_schedule": [] + }, + "regularization": { + "L2": 0.0 + } + } + } + }, + "name": "pred" + }, + "eval": { + "name": null, + "env": "nusc", + "dataset_path": null, + "eval_class": "", + "seed": 0, + "num_scenes_per_batch": 4, + "num_scenes_to_evaluate": 100, + "num_episode_repeats": 3, + "start_frame_index_each_episode": null, + "seed_each_episode": null, + "ego_only": false, + "agent_eval_class": null, + "ckpt_root_dir": "checkpoints/", + "experience_hdf5_path": null, + "results_dir": "results/", + "ckpt": { + "policy": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "planner": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "predictor": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "cvae_metric": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "occupancy_metric": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + } + }, + "policy": { + "mask_drivable": true, + "num_plan_samples": 50, + "num_action_samples": 10, + "pos_to_yaw": true, + "yaw_correction_speed": 1.0, + "diversification_clearance": null, + "sample": false, + "cost_weights": { + "collision_weight": 10.0, + "lane_weight": 1.0, + "likelihood_weight": 0.0, + "progress_weight": 0.0 + } + }, + "metrics": { + "compute_analytical_metrics": true, + "compute_learned_metrics": false + }, + "perturb": { + "enabled": false, + "OU": { + "theta": 0.8, + "sigma": [ + 0.0, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 4.0 + ], + "scale": [ + 1.0, + 1.0, + 0.2 + ] + } + }, + "rolling_perturb": { + "enabled": false, + "OU": { + "theta": 0.8, + "sigma": 0.5, + "scale": [ + 1.0, + 1.0, + 0.2 + ] + } + }, + "occupancy": { + "rolling": true, + "rolling_horizon": [ + 5, + 10, + 20 + ] + }, + "cvae": { + "rolling": true, + "rolling_horizon": [ + 5, + 10, + 20 + ] + }, + "nusc": { + "eval_scenes": [ + 0, + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90 + ], + "n_step_action": 5, + "num_simulation_steps": 200, + "skip_first_n": 0 + }, + "l5kit": { + "eval_scenes": [ + 9058, + 5232, + 14153, + 8173, + 10314, + 7027, + 9812, + 1090, + 9453, + 978, + 10263, + 874, + 5563, + 9613, + 261, + 2826, + 2175, + 9977, + 6423, + 1069 + ], + "n_step_action": 5, + "num_simulation_steps": 200, + "skip_first_n": 1, + "skimp_rollout": false + }, + "adjustment": { + "random_init_plan": true, + "remove_existing_neighbors": false, + "initial_num_neighbors": 4, + "num_frame_per_new_agent": 20 + } + }, + "stack_type": "pred", + "name": "test", + "root_dir": "predictor_sceneformer_trained_models/", + "seed": 1, + "devices": { + "num_gpus": 1 + } +} \ No newline at end of file diff --git a/config/templates/SceneFormerPredStack_nuplan.json b/config/templates/SceneFormerPredStack_nuplan.json new file mode 100644 index 0000000..f1c3337 --- /dev/null +++ b/config/templates/SceneFormerPredStack_nuplan.json @@ -0,0 +1,358 @@ +{ + "registered_name": "SceneFormerPredStack", + "train": { + "logging": { + "terminal_output_to_txt": true, + "log_tb": false, + "log_wandb": true, + "wandb_project_name": "diffstack", + "log_every_n_steps": 10, + "flush_every_n_steps": 100 + }, + "save": { + "enabled": true, + "every_n_steps": 3000, + "best_k": 10, + "save_best_rollout": false, + "save_best_validation": true + }, + "rollout": { + "save_video": true, + "enabled": false, + "every_n_steps": 5000, + "warm_start_n_steps": 1 + }, + "training": { + "batch_size": 20, + "num_steps": 1000000, + "num_data_workers": 64 + }, + "validation": { + "enabled": true, + "batch_size": 20, + "num_data_workers": 64, + "every_n_steps": 1500, + "num_steps_per_epoch": 200 + }, + "parallel_strategy": "ddp", + "rebuild_cache": false, + "on_ngc": false, + "trajdata_source_train": "train", + "trajdata_source_valid": "val", + "trajdata_source_root": "nuplan_train", + "trajdata_val_source_root": "nuplan_val", + "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", + "datamodule_class": "UnifiedDataModule", + "ego_only": true, + "amp": true, + "auto_batch_size": false, + "max_batch_size": 36, + "gradient_clip_val": 0.5 + }, + "env": { + "name": "nusc_trainval", + "rasterizer": { + "raster_size": 224, + "pixel_size": 0.5, + "ego_center": [ + -0.75, + 0.0 + ] + }, + "data_generation_params": { + "other_agents_num": 11, + "max_agents_distance": 40, + "yaw_correction_speed": 1.0 + }, + "simulation": { + "distance_th_close": 30, + "num_simulation_steps": 50, + "start_frame_index": 0, + "vectorize_lane": "ego" + }, + "incl_neighbor_map": false, + "incl_vector_map": true, + "incl_raster_map": false, + "calc_lane_graph": true, + "max_num_lanes": 32, + "num_lane_pts": 32, + "remove_single_successor": true + }, + "stack": { + "predictor": { + "name": "sceneformer", + "step_time": 0.25, + "history_num_frames": 6, + "future_num_frames": 12, + "n_embd": 128, + "n_head": 4, + "PE_mode": "PE", + "use_rpe_net": false, + "enc_nblock": 2, + "dec_nblock": 2, + "edge_dim": { + "a2a": 14, + "a2l": 12, + "l2a": 12, + "l2l": 16 + }, + "a2l_edge_type": "proj", + "a2l_n_embd": 64, + "attn_ntype": { + "a2a": 2, + "a2l": 1, + "l2l": 2 + }, + "lane_GNN_num_layers": 4, + "homotopy_GNN_num_layers": 4, + "closed_loop": false, + "CL_Tf_mode": 6, + "CL_step": 2, + "encoder": { + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "edge_scale": 20, + "edge_clip": [ + -4, + 4 + ], + "mode_embed_dim": 64, + "jm_GNN_nblock": 2, + "num_joint_samples": 30, + "num_joint_factors": 7, + "null_lane_mode": true + }, + "decoder": { + "arch": "mlp", + "lstm_hidden_size": 128, + "mlp_hidden_dims": [ + 128, + 256 + ], + "traj_dim": 4, + "num_layers": 2, + "dyn": { + "vehicle": "unicycle", + "pedestrian": "DI_unicycle" + }, + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "decode_num_modes": 3, + "AR_step_size": 1, + "GNN_enabled": false, + "AR_update_mode": "step", + "dec_rounds": 5 + }, + "num_lane_pts": 32, + "hist_lane_relation": "LaneRelation", + "fut_lane_relation": "SimpleLaneRelation", + "classify_a2l_4all_lanes": false, + "max_joint_cardinality": 5, + "loss_weights": { + "marginal_lm_loss": 5.0, + "marginal_homo_loss": 5.0, + "joint_prob_loss": 5.0, + "xy_loss": 4.0, + "heading_loss": 1.0, + "l2_reg": 0.0001, + "lm_consistency_loss": 5.0, + "homotopy_consistency_loss": 5.0, + "coll_loss": 2.0, + "acce_reg_loss": 0.05, + "steering_reg_loss": 0.2, + "input_violation_loss": 20.0, + "jerk_loss": 0.1 + }, + "loss": { + "lm_margin_offset": 0.2 + }, + "weighted_consistency_loss": false, + "LR_sample_hack": true, + "scene_centric": true, + "optim_params": { + "policy": { + "learning_rate": { + "initial": 0.0001, + "decay_factor": 0.05, + "epoch_schedule": [] + }, + "regularization": { + "L2": 0.0 + } + } + } + }, + "name": "pred" + }, + "eval": { + "name": null, + "env": "nusc", + "dataset_path": null, + "eval_class": "", + "seed": 0, + "num_scenes_per_batch": 4, + "num_scenes_to_evaluate": 100, + "num_episode_repeats": 3, + "start_frame_index_each_episode": null, + "seed_each_episode": null, + "ego_only": false, + "agent_eval_class": null, + "ckpt_root_dir": "checkpoints/", + "experience_hdf5_path": null, + "results_dir": "results/", + "ckpt": { + "policy": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "planner": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "predictor": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "cvae_metric": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "occupancy_metric": { + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + } + }, + "policy": { + "mask_drivable": true, + "num_plan_samples": 50, + "num_action_samples": 10, + "pos_to_yaw": true, + "yaw_correction_speed": 1.0, + "diversification_clearance": null, + "sample": false, + "cost_weights": { + "collision_weight": 10.0, + "lane_weight": 1.0, + "likelihood_weight": 0.0, + "progress_weight": 0.0 + } + }, + "metrics": { + "compute_analytical_metrics": true, + "compute_learned_metrics": false + }, + "perturb": { + "enabled": false, + "OU": { + "theta": 0.8, + "sigma": [ + 0.0, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 4.0 + ], + "scale": [ + 1.0, + 1.0, + 0.2 + ] + } + }, + "rolling_perturb": { + "enabled": false, + "OU": { + "theta": 0.8, + "sigma": 0.5, + "scale": [ + 1.0, + 1.0, + 0.2 + ] + } + }, + "occupancy": { + "rolling": true, + "rolling_horizon": [ + 5, + 10, + 20 + ] + }, + "cvae": { + "rolling": true, + "rolling_horizon": [ + 5, + 10, + 20 + ] + }, + "nusc": { + "eval_scenes": [ + 0, + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90 + ], + "n_step_action": 5, + "num_simulation_steps": 200, + "skip_first_n": 0 + }, + "l5kit": { + "eval_scenes": [ + 9058, + 5232, + 14153, + 8173, + 10314, + 7027, + 9812, + 1090, + 9453, + 978, + 10263, + 874, + 5563, + 9613, + 261, + 2826, + 2175, + 9977, + 6423, + 1069 + ], + "n_step_action": 5, + "num_simulation_steps": 200, + "skip_first_n": 1, + "skimp_rollout": false + }, + "adjustment": { + "random_init_plan": true, + "remove_existing_neighbors": false, + "initial_num_neighbors": 4, + "num_frame_per_new_agent": 20 + } + }, + "stack_type": "pred", + "name": "test", + "root_dir": "predictor_sceneformer_trained_models/", + "seed": 1, + "devices": { + "num_gpus": 4 + } +} \ No newline at end of file diff --git a/diffstack/__init__.py b/diffstack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/configs/__init__.py b/diffstack/configs/__init__.py new file mode 100644 index 0000000..2c33122 --- /dev/null +++ b/diffstack/configs/__init__.py @@ -0,0 +1,2 @@ +from diffstack.configs.base import ExperimentConfig + diff --git a/diffstack/configs/algo_config.py b/diffstack/configs/algo_config.py new file mode 100644 index 0000000..e26fc71 --- /dev/null +++ b/diffstack/configs/algo_config.py @@ -0,0 +1,247 @@ +import math + +from diffstack.configs.base import AlgoConfig + + +class BehaviorCloningConfig(AlgoConfig): + def __init__(self): + super(BehaviorCloningConfig, self).__init__() + self.eval_class = "BC" + + self.name = "bc" + self.model_architecture = "resnet18" + self.map_feature_dim = 256 + self.history_num_frames = 8 + self.history_num_frames_ego = 8 + self.history_num_frames_agents = 8 + self.future_num_frames = 6 + self.step_time = 0.5 + self.render_ego_history = False + + self.decoder.layer_dims = () + self.decoder.state_as_input = True + + self.dynamics.type = "Unicycle" + self.dynamics.max_steer = 0.5 + self.dynamics.max_yawvel = math.pi * 2.0 + self.dynamics.acce_bound = (-10, 8) + self.dynamics.ddh_bound = (-math.pi * 2.0, math.pi * 2.0) + self.dynamics.max_speed = 40.0 # roughly 90mph + + self.spatial_softmax.enabled = False + self.spatial_softmax.kwargs.num_kp = 32 + self.spatial_softmax.kwargs.temperature = 1.0 + self.spatial_softmax.kwargs.learnable_temperature = False + + self.loss_weights.prediction_loss = 1.0 + self.loss_weights.goal_loss = 0.0 + self.loss_weights.collision_loss = 0.0 + self.loss_weights.yaw_reg_loss = 0.001 + + self.optim_params.policy.learning_rate.initial = 1e-3 # policy learning rate + self.optim_params.policy.learning_rate.decay_factor = ( + 0.1 # factor to decay LR by (if epoch schedule non-empty) + ) + self.optim_params.policy.learning_rate.epoch_schedule = ( + [] + ) # epochs where LR decay occurs + self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength + self.checkpoint.enabled = False + self.checkpoint.path = None + + +class AgentFormerConfig(AlgoConfig): + def __init__(self): + super(AgentFormerConfig, self).__init__() + self.name = "agentformer" + self.seed = 1 + self.load_map = False + self.dynamic_type = "Unicycle" + self.step_time = 0.1 + self.history_num_frames = 10 + self.future_num_frames = 20 + self.traj_scale = 10 + self.nz = 32 + self.sample_k = 4 + self.tf_model_dim = 256 + self.tf_ff_dim = 512 + self.tf_nhead = 8 + self.tf_dropout = 0.1 + self.z_tau.start = 0.5 + self.z_tau.finish = 0.0001 + self.z_tau.decay = 0.5 + self.input_type = ["scene_norm", "vel", "heading"] + self.fut_input_type = ["scene_norm", "vel", "heading"] + self.dec_input_type = ["heading"] + self.pred_type = "dynamic" + self.sn_out_type = "norm" + self.sn_out_heading = False + self.pos_concat = True + self.rand_rot_scene = False + self.use_map = True + self.pooling = "mean" + self.agent_enc_shuffle = False + self.vel_heading = False + self.max_agent_len = 128 + self.agent_enc_learn = False + self.use_agent_enc = False + self.motion_dim = 2 + self.forecast_dim = 2 + self.z_type = "gaussian" + self.nlayer = 6 + self.ar_detach = True + self.pred_scale = 1.0 + self.pos_offset = False + self.learn_prior = True + self.discrete_rot = False + self.map_global_rot = False + self.ar_train = True + self.max_train_agent = 100 + self.num_eval_samples = 5 + + self.UAC = False # compare unconditional and conditional prediction + + self.loss_cfg.kld.min_clip = 1.0 + self.loss_cfg.sample.weight = 1.0 + self.loss_cfg.sample.k = 20 + self.loss_weights.prediction_loss = 1.0 + self.loss_weights.kl_loss = 1.0 + self.loss_weights.collision_loss = 3.0 + self.loss_weights.EC_collision_loss = 5.0 + self.loss_weights.diversity_loss = 0.3 + self.loss_weights.deviation_loss = 0.1 + self.scene_orig_all_past = False + self.conn_dist = 100000.0 + self.scene_centric = True + self.stage = 2 + self.num_frames_per_stage = 10 + + self.ego_conditioning = True + self.perturb.enabled = True + self.perturb.N_pert = 1 + self.perturb.OU.theta = 0.8 + self.perturb.OU.sigma = 2.0 + self.perturb.OU.scale = [1.0, 0.3] + + self.map_encoder.model_architecture = "resnet18" + self.map_encoder.image_shape = [3, 224, 224] + self.map_encoder.feature_dim = 32 + self.map_encoder.spatial_softmax.enabled = False + self.map_encoder.spatial_softmax.kwargs.num_kp = 32 + self.map_encoder.spatial_softmax.kwargs.temperature = 1.0 + self.map_encoder.spatial_softmax.kwargs.learnable_temperature = False + + self.context_encoder.nlayer = 2 + + self.future_decoder.nlayer = 2 + self.future_decoder.out_mlp_dim = [512, 256] + self.future_encoder.nlayer = 2 + + self.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate + self.optim_params.policy.learning_rate.decay_factor = ( + 0.1 # factor to decay LR by (if epoch schedule non-empty) + ) + self.optim_params.policy.learning_rate.epoch_schedule = ( + [] + ) # epochs where LR decay occurs + self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength + + +class CTTConfig(AlgoConfig): + def __init__(self): + super(CTTConfig, self).__init__() + self.name = "CTT" + self.step_time = 0.25 + self.history_num_frames = 4 + self.future_num_frames = 12 + + self.n_embd = 256 + self.n_head = 8 + self.use_rpe_net = True + self.PE_mode = "RPE" # "RPE" or "PE" + + self.enc_nblock = 3 + self.dec_nblock = 3 + + self.encoder.attn_pdrop = 0.05 + self.encoder.resid_pdrop = 0.05 + self.encoder.pooling = "attn" + self.encoder.edge_scale = 20 + self.encoder.edge_clip = [-4, 4] + self.encoder.mode_embed_dim = 64 + self.encoder.jm_GNN_nblock = 2 + self.encoder.num_joint_samples = 30 + self.encoder.num_joint_factors = 6 + self.encoder.null_lane_mode = True + + self.edge_dim.a2a = 14 + self.edge_dim.a2l = 12 + self.edge_dim.l2a = 12 + self.edge_dim.l2l = 16 + self.a2l_edge_type = "attn" + self.a2l_n_embd = 64 + + self.attn_ntype.a2a = 2 + self.attn_ntype.a2l = 1 + self.attn_ntype.l2l = 2 + + self.lane_GNN_num_layers = 4 + self.homotpy_GNN_num_layers = 4 + + self.hist_lane_relation = "LaneRelation" + self.fut_lane_relation = "SimpleLaneRelation" + self.classify_a2l_4all_lanes = ( + False # Alternative is 4all modes, find resulting lane + ) + + self.decoder.arch = "mlp" + self.decoder.lstm_hidden_size = 128 + self.decoder.mlp_hidden_dims = [128, 256] + self.decoder.traj_dim = 4 # x,y,v,yaw + self.decoder.num_layers = 2 + self.decoder.dyn.vehicle = "unicycle" + self.decoder.dyn.pedestrian = "DI_unicycle" + self.decoder.attn_pdrop = 0.05 + self.decoder.resid_pdrop = 0.05 + self.decoder.pooling = "attn" + self.decoder.decode_num_modes = 5 + self.decoder.AR_step_size = 1 + self.decoder.GNN_enabled = False + self.decoder.AR_update_mode = None + self.decoder.dec_rounds = 5 + + self.num_lane_pts = 30 + + self.loss_weights.marginal_lm_loss = 5.0 + self.loss_weights.marginal_homo_loss = 5.0 + self.loss_weights.joint_prob_loss = 5.0 + self.loss_weights.xy_loss = 4.0 + self.loss_weights.yaw_loss = 1.0 + self.loss_weights.lm_consistency_loss = 5.0 + self.loss_weights.homotopy_consistency_loss = 5.0 + self.loss_weights.yaw_dev_loss = 30.0 + self.loss_weights.y_dev_loss = 5.0 + self.loss_weights.l2_reg = 0.1 + self.loss_weights.coll_loss = 0.1 + self.loss_weights.acce_reg_loss = 0.04 + self.loss_weights.steering_reg_loss = 0.1 + self.loss_weights.input_violation_loss = 20.0 + self.loss_weights.jerk_loss = 0.05 + + self.loss.lm_margin_offset = 0.2 + + self.weighted_consistency_loss = False + self.LR_sample_hack = True + + self.scene_centric = True + + self.max_joint_cardinality = 5 + + self.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate + self.optim_params.policy.learning_rate.decay_factor = ( + 0.1 # factor to decay LR by (if epoch schedule non-empty) + ) + self.optim_params.policy.learning_rate.epoch_schedule = ( + [] + ) # epochs where LR decay occurs + self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength diff --git a/diffstack/configs/base.py b/diffstack/configs/base.py new file mode 100644 index 0000000..31edf6f --- /dev/null +++ b/diffstack/configs/base.py @@ -0,0 +1,139 @@ +from diffstack.configs.config import Dict +from copy import deepcopy +from diffstack.configs.eval_config import EvaluationConfig + + +class TrainConfig(Dict): + def __init__(self): + super(TrainConfig, self).__init__() + self.logging.terminal_output_to_txt = True # whether to log stdout to txt file + self.logging.log_tb = False # enable tensorboard logging + self.logging.log_wandb = True # enable wandb logging + self.logging.wandb_project_name = "diffstack" + self.logging.log_every_n_steps = 10 + self.logging.flush_every_n_steps = 100 + + ## save config - if and when to save model checkpoints ## + self.save.enabled = True # whether model saving should be enabled or disabled + self.save.every_n_steps = 100 # save model every n epochs + self.save.best_k = 5 + self.save.save_best_rollout = False + self.save.save_best_validation = True + + ## evaluation rollout config ## + self.rollout.save_video = True + self.rollout.enabled = False # enable evaluation rollouts + self.rollout.every_n_steps = 1000 # do rollouts every @rate epochs + self.rollout.warm_start_n_steps = ( + 1 # number of steps to wait before starting rollouts + ) + + ## training config + self.training.batch_size = 100 + self.training.num_steps = 200000 + self.training.num_data_workers = 0 + + ## validation config + self.validation.enabled = True + self.validation.batch_size = 100 + self.validation.num_data_workers = 0 + self.validation.every_n_steps = 1000 + self.validation.num_steps_per_epoch = 100 + + ## Training parallelism (e.g., multi-GPU) + self.parallel_strategy = "ddp" + + self.rebuild_cache = False + + self.on_ngc = False + + # AMP + self.amp = False + + # auto batch size + self.auto_batch_size = False + self.max_batch_size = 1000 + # graidient clipping + self.gradient_clip_val = 0.5 + + +class EnvConfig(Dict): + def __init__(self): + super(EnvConfig, self).__init__() + self.name = "my_env" + + +class AlgoConfig(Dict): + def __init__(self): + super(AlgoConfig, self).__init__() + self.name = "my_algo" + + +class ExperimentConfig(Dict): + def __init__( + self, + train_config: TrainConfig, + env_config: EnvConfig, + module_configs: dict, + eval_config: EvaluationConfig = None, + registered_name: str = None, + stack_type: str = None, + name: str = None, + root_dir=None, + seed=None, + devices=None, + ): + """ + + Args: + train_config (TrainConfig): training config + env_config (EnvConfig): environment config + module_configs dict(AlgoConfig): algorithm configs for all modules + registered_name (str): name of the experiment config object in the global config registry + """ + super(ExperimentConfig, self).__init__() + self.registered_name = registered_name + + self.train = train_config + self.env = env_config + self.stack = module_configs + self.stack.stack_type = stack_type + self.eval = EvaluationConfig() if eval_config is None else eval_config + + # Write all results to this directory. A new folder with the timestamp will be created + # in this directory, and it will contain three subfolders - "log", "models", and "videos". + # The "log" directory will contain tensorboard and stdout txt logs. The "models" directory + # will contain saved model checkpoints. The "videos" directory contains evaluation rollout + # videos. + self.name = ( + "test" # name of the experiment (creates a subdirectory under root_dir) + ) + stack_name = "" + for key, config in self.stack.items(): + if isinstance(config, dict): + stack_name += key + "_" + config["name"] + + self.root_dir = ( + "{}_trained_models/".format(stack_name) if root_dir is None else root_dir + ) + self.seed = ( + 1 if seed is None else seed + ) # seed for everything (for reproducibility) + + self.devices = ( + Dict(num_gpus=1) if devices is None else devices + ) # Set to 0 to use CPU + + def clone(self): + return self.__class__( + train_config=deepcopy(self.train), + env_config=deepcopy(self.env), + module_configs=deepcopy(self.stack), + eval_config=deepcopy(self.eval), + registered_name=self.registered_name, + stack_type=self.stack.stack_type, + name=self.name, + root_dir=self.root_dir, + seed=self.seed, + devices=self.devices, + ) diff --git a/diffstack/configs/config.py b/diffstack/configs/config.py new file mode 100644 index 0000000..9dec7a4 --- /dev/null +++ b/diffstack/configs/config.py @@ -0,0 +1,190 @@ +""" +Basic config class - provides a convenient way to work with nested +dictionaries (by exposing keys as attributes) and to save / load from jsons. + +Based on addict: https://github.com/mewwts/addict +""" + +import json +import copy +import contextlib +from copy import deepcopy + + +class Dict(dict): + + def __init__(__self, *args, **kwargs): + object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) + object.__setattr__(__self, '__key', kwargs.pop('__key', None)) + object.__setattr__(__self, '__frozen', False) + for arg in args: + if not arg: + continue + elif isinstance(arg, dict): + for key, val in arg.items(): + __self[key] = __self._hook(val) + elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): + __self[arg[0]] = __self._hook(arg[1]) + else: + for key, val in iter(arg): + __self[key] = __self._hook(val) + + for key, val in kwargs.items(): + __self[key] = __self._hook(val) + + def __setattr__(self, name, value): + if hasattr(self.__class__, name): + raise AttributeError("'Dict' object attribute " + "'{0}' is read-only".format(name)) + else: + self[name] = value + + def __setitem__(self, name, value): + isFrozen = (hasattr(self, '__frozen') and + object.__getattribute__(self, '__frozen')) + if isFrozen and name not in super(Dict, self).keys(): + raise KeyError(name) + super(Dict, self).__setitem__(name, value) + try: + p = object.__getattribute__(self, '__parent') + key = object.__getattribute__(self, '__key') + except AttributeError: + p = None + key = None + if p is not None: + p[key] = self + object.__delattr__(self, '__parent') + object.__delattr__(self, '__key') + + def __add__(self, other): + if not self.keys(): + return other + else: + self_type = type(self).__name__ + other_type = type(other).__name__ + msg = "unsupported operand type(s) for +: '{}' and '{}'" + raise TypeError(msg.format(self_type, other_type)) + + @classmethod + def _hook(cls, item): + if isinstance(item, dict): + return cls(item) + elif isinstance(item, (list, tuple)): + return type(item)(cls._hook(elem) for elem in item) + return item + + def __getattr__(self, item): + return self.__getitem__(item) + + def __missing__(self, name): + if object.__getattribute__(self, '__frozen'): + raise KeyError(name) + return Dict(__parent=self, __key=name) + + def __delattr__(self, name): + del self[name] + + def __repr__(self): + json_string = json.dumps(self.to_dict(), indent=4) + return json_string + + def to_dict(self): + base = {} + for key, value in self.items(): + if isinstance(value, type(self)): + base[key] = value.to_dict() + elif isinstance(value, (list, tuple)): + base[key] = type(value)( + item.to_dict() if isinstance(item, type(self)) else + item for item in value) + else: + base[key] = value + return base + + def copy(self): + return copy.copy(self) + + def deepcopy(self): + return copy.deepcopy(self) + + def __deepcopy__(self, memo): + other = self.__class__() + memo[id(self)] = other + for key, value in self.items(): + other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) + return other + + def update(self, *args, **kwargs): + other = {} + if args: + if len(args) > 1: + raise TypeError() + other.update(args[0]) + other.update(kwargs) + for k, v in other.items(): + if ((k not in self) or + (not isinstance(self[k], dict)) or + (not isinstance(v, dict))): + self[k] = v + else: + self[k].update(v) + + def __getnewargs__(self): + return tuple(self.items()) + + def __getstate__(self): + return self + + def __setstate__(self, state): + self.update(state) + + def __or__(self, other): + if not isinstance(other, (Dict, dict)): + return NotImplemented + new = Dict(self) + new.update(other) + return new + + def __ror__(self, other): + if not isinstance(other, (Dict, dict)): + return NotImplemented + new = Dict(other) + new.update(self) + return new + + def __ior__(self, other): + self.update(other) + return self + + def setdefault(self, key, default=None): + if key in self: + return self[key] + else: + self[key] = default + return default + + def lock(self, should_lock=True): + object.__setattr__(self, '__frozen', should_lock) + for key, val in self.items(): + if isinstance(val, Dict): + val.lock(should_lock) + + def dump(self, filename = None): + json_string = json.dumps(self.to_dict(), indent=4) + if filename is not None: + f = open(filename, "w") + f.write(json_string) + f.close() + return json_string + + def unlock(self): + self.lock(False) + + @contextlib.contextmanager + def unlocked(self): + self.unlock() + yield + self.lock() + + def clone(self): + return deepcopy(self) \ No newline at end of file diff --git a/diffstack/configs/eval_config.py b/diffstack/configs/eval_config.py new file mode 100644 index 0000000..cb479c4 --- /dev/null +++ b/diffstack/configs/eval_config.py @@ -0,0 +1,126 @@ +import numpy as np +from copy import deepcopy + +from diffstack.configs.config import Dict + + +class SimEvaluationConfig(Dict): + def __init__(self): + super(SimEvaluationConfig, self).__init__() + self.name = None + self.env = "nusc" + self.dataset_path = None + self.eval_class = "" + self.seed = 0 + self.num_scenes_per_batch = 4 + self.num_scenes_to_evaluate = 100 + + self.num_episode_repeats = 3 + self.start_frame_index_each_episode = None # if specified, should be the same length as num_episode_repeats + self.seed_each_episode = None # if specified, should be the same length as num_episode_repeats + + self.ego_only = False + self.agent_eval_class = None + + self.ckpt_root_dir = "checkpoints/" + self.experience_hdf5_path = None + self.results_dir = "results/" + + self.ckpt.policy.ngc_job_id = None + self.ckpt.policy.ckpt_dir = None + self.ckpt.policy.ckpt_key = None + + self.ckpt.planner.ngc_job_id = None + self.ckpt.planner.ckpt_dir = None + self.ckpt.planner.ckpt_key = None + + self.ckpt.predictor.ngc_job_id = None + self.ckpt.predictor.ckpt_dir = None + self.ckpt.predictor.ckpt_key = None + + self.ckpt.cvae_metric.ngc_job_id = None + self.ckpt.cvae_metric.ckpt_dir = None + self.ckpt.cvae_metric.ckpt_key = None + + self.ckpt.occupancy_metric.ngc_job_id = None + self.ckpt.occupancy_metric.ckpt_dir = None + self.ckpt.occupancy_metric.ckpt_key = None + + self.policy.mask_drivable = True + self.policy.num_plan_samples = 50 + self.policy.num_action_samples = 10 + self.policy.pos_to_yaw = True + self.policy.yaw_correction_speed = 1.0 + self.policy.diversification_clearance = None + self.policy.sample = True + + + self.policy.cost_weights.collision_weight = 10.0 + self.policy.cost_weights.lane_weight = 1.0 + self.policy.cost_weights.likelihood_weight = 0.0 # 0.1 + self.policy.cost_weights.progress_weight = 0.0 # 0.005 + + self.metrics.compute_analytical_metrics = True + self.metrics.compute_learned_metrics = False + + self.perturb.enabled = False + self.perturb.OU.theta = 0.8 + self.perturb.OU.sigma = [0.0, 0.1,0.2,0.5,1.0,2.0,4.0] + self.perturb.OU.scale = [1.0,1.0,0.2] + + self.rolling_perturb.enabled = False + self.rolling_perturb.OU.theta = 0.8 + self.rolling_perturb.OU.sigma = 0.5 + self.rolling_perturb.OU.scale = [1.0,1.0,0.2] + + self.occupancy.rolling = True + self.occupancy.rolling_horizon = [5,10,20] + + self.cvae.rolling = True + self.cvae.rolling_horizon = [5,10,20] + + self.nusc.eval_scenes = np.arange(100).tolist() + self.nusc.n_step_action = 5 + self.nusc.num_simulation_steps = 200 + self.nusc.skip_first_n = 0 + + + self.adjustment.random_init_plan=True + self.adjustment.remove_existing_neighbors = False + self.adjustment.initial_num_neighbors = 4 + self.adjustment.num_frame_per_new_agent = 20 + + def clone(self): + return deepcopy(self) + +class EvaluationConfig(Dict): + def __init__(self): + super(EvaluationConfig, self).__init__() + self.name = None + self.env = "nusc" + self.dataset_path = None + self.eval_class = "" + self.seed = 0 + self.ckpt_root_dir = "checkpoints/" + self.ckpt.dir = None + self.ckpt.ngc_job_id = None + self.ckpt.ckpt_dir = None + self.ckpt.ckpt_key = None + + self.eval.batch_size = 100 + self.eval.num_steps = None + self.eval.num_data_workers = 8 + self.log_image_frequency=None + self.log_all_image = False + + self.trajdata_source_root = "nusc_trainval" + self.trajdata_source_eval = "val" + + +class TrainTimeEvaluationConfig(SimEvaluationConfig): + def __init__(self): + super(TrainTimeEvaluationConfig, self).__init__() + + self.num_scenes_per_batch = 4 + self.nusc.eval_scenes = np.arange(0, 100, 10).tolist() + self.policy.sample = False diff --git a/diffstack/configs/registry.py b/diffstack/configs/registry.py new file mode 100644 index 0000000..e3e9ce8 --- /dev/null +++ b/diffstack/configs/registry.py @@ -0,0 +1,42 @@ +"""A global registry for looking up named experiment configs""" +from diffstack.configs.base import ExperimentConfig +from diffstack.configs.config import Dict + + +from diffstack.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig + +from diffstack.configs.algo_config import ( + AgentFormerConfig, + CTTConfig, +) + + +EXP_CONFIG_REGISTRY = dict() + + +EXP_CONFIG_REGISTRY["AFPredStack"] = ExperimentConfig( + train_config=TrajdataTrainConfig(), + env_config=TrajdataEnvConfig(), + module_configs=Dict(predictor=AgentFormerConfig()), + registered_name="AFPredStack", + stack_type="pred", +) + + +EXP_CONFIG_REGISTRY["CTTPredStack"] = ExperimentConfig( + train_config=TrajdataTrainConfig(), + env_config=TrajdataEnvConfig(), + module_configs=Dict(predictor=CTTConfig()), + registered_name="CTTPredStack", + stack_type="pred", +) + + +def get_registered_experiment_config(registered_name): + if registered_name not in EXP_CONFIG_REGISTRY.keys(): + raise KeyError( + "'{}' is not a registered experiment config please choose from {}".format( + registered_name, list(EXP_CONFIG_REGISTRY.keys()) + ) + ) + return EXP_CONFIG_REGISTRY[registered_name].clone() diff --git a/diffstack/configs/trajdata_config.py b/diffstack/configs/trajdata_config.py new file mode 100644 index 0000000..def7b9c --- /dev/null +++ b/diffstack/configs/trajdata_config.py @@ -0,0 +1,99 @@ +from diffstack.configs.base import TrainConfig, EnvConfig + +MAX_POINTS_LANE = 5 + + +class TrajdataTrainConfig(TrainConfig): + def __init__(self): + super(TrajdataTrainConfig, self).__init__() + + self.trajdata_source_train = "train" + self.trajdata_source_valid = "val" + self.trajdata_source_test = None + self.trajdata_source_root = "nusc_trainval" + self.trajdata_val_source_root = None + self.trajdata_test_source_root = None + # self.trajdata_source_train = "mini_train" + # self.trajdata_source_valid = "mini_val" + # self.trajdata_source_root = "nusc_mini" + + self.dataset_path = "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS" + self.datamodule_class = "UnifiedDataModule" + self.ego_only = True + + self.rollout.enabled = False + self.rollout.save_video = True + self.rollout.every_n_steps = 5000 + + # training config + self.training.batch_size = 100 + self.training.num_steps = 200000 + self.training.num_data_workers = 8 + + self.save.every_n_steps = 1000 + self.save.best_k = 10 + + # validation config + self.validation.enabled = True + self.validation.batch_size = 32 + self.validation.num_data_workers = 6 + self.validation.every_n_steps = 500 + self.validation.num_steps_per_epoch = 50 + + self.test.enabled = False + self.test.batch_size = 32 + self.test.num_data_workers = 6 + self.test.every_n_steps = 500 + self.test.num_steps_per_epoch = 50 + + +class TrajdataEnvConfig(EnvConfig): + def __init__(self): + super(TrajdataEnvConfig, self).__init__() + + self.name = "nusc_trainval" + + # raster image size [pixels] + self.rasterizer.raster_size = 224 + + # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. + self.rasterizer.pixel_size = 0.5 + + # where the agent is on the map, (0.0, 0.0) is the center + # WARNING: this should not be changed before resolving + self.rasterizer.ego_center = (-0.75, 0.0) + + # maximum number of agents to consider during training + self.data_generation_params.other_agents_num = 10 + + self.data_generation_params.max_agents_distance = 30 + + # correct for yaw (zero-out delta yaw) when speed is lower than this threshold + self.data_generation_params.yaw_correction_speed = 1.0 + + self.simulation.distance_th_close = 30 + + # maximum number of simulation steps to run (0.1sec / step) + self.simulation.num_simulation_steps = 50 + + # which frame to start an simulation episode with + self.simulation.start_frame_index = 0 + + # whether to get lane information + self.simulation.vectorize_lane = "ego" + + # whether include neighbor map patches + self.incl_neighbor_map = False + # whether include vectorized map + self.incl_vector_map = True + # whether include rasterized map + self.incl_raster_map = False + # whether to combine lanes with a single successor + self.remove_single_successor = False # be very careful, don't turn this on unless you know what you are doing + # whether to prepare lane graphs + self.calc_lane_graph = True + # maximum number of lanes to consider + self.max_num_lanes = 15 + # number of lane points to include in the polyline + self.num_lane_pts = 30 + self.remove_parked = False diff --git a/diffstack/data/__init__.py b/diffstack/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/data/agent_batch_extras.py b/diffstack/data/agent_batch_extras.py new file mode 100644 index 0000000..9ae2b62 --- /dev/null +++ b/diffstack/data/agent_batch_extras.py @@ -0,0 +1,183 @@ +import torch +import numpy as np +from typing import Dict, Iterable, Union + +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.data_structures.agent import AgentType + +from diffstack.utils.utils import move_list_element_to_front +from diffstack.data.trajdata_lanes import get_goal_lanes, get_lane_projection_points, LanesList + + +def robot_selector(element: AgentBatchElement): + # Find most relevant neighbor + dists = [] + inds = [] + for n_i in range(len(element.neighbor_futures)): + # Filter vehicles + if element.neighbor_types_np[n_i] != AgentType.VEHICLE: + continue + + # Filter incomplete future + if element.neighbor_futures[n_i].shape[0] < element.agent_future_np.shape[0]: + continue + + # Filter parked vehicles + # We used to do this for all dataset variants EXCEPT for v7 + # + # # Implementation 1: compute distance for valid history and future + # # Agent is considered to be parked if it moves less then 1m from the beginning of history to the end of future. + # start_to_end_dist = np.linalg.norm(element.neighbor_histories[n_i][0, :2] - element.neighbor_futures[n_i][-1, :2]) + # if start_to_end_dist < 1.: + # continue + # + # Implementation 2: use pre-computed metainfo based on entire valid trajectory + if element.neighbor_meta_dicts[n_i]['is_stationary']: + continue + + # Distance from predicted agent + dist = np.square(element.agent_future_np[:, :2] - element.neighbor_futures[n_i][:, :2]) # [1:] exclude current state for vehicle future + dist = np.min(dist.sum(axis=-1), axis=-1) # sum over states, min over time + inds.append(n_i) + dists.append(dist) + + if dists: + plan_i = inds[np.argmin(np.array(dists))] # neighbor that gets closest to current node + else: + # No neighbors or all futures are incomplete + plan_i = -1 + + + element.extras['robot_ind'] = plan_i + return element + + +def make_robot_the_first(element: AgentBatchElement): + """Reorder neighbors such that the first neighbor is the proxy robot agent for planning.""" + + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, do nothing + pass + else: + element.neighbor_futures = move_list_element_to_front(element.neighbor_futures, robot_ind) + element.neighbor_future_lens_np = np.array(move_list_element_to_front(element.neighbor_future_lens_np, robot_ind)) + element.neighbor_future_extents = move_list_element_to_front(element.neighbor_future_extents, robot_ind) + element.neighbor_histories = move_list_element_to_front(element.neighbor_histories, robot_ind) + element.neighbor_history_lens_np = np.array(move_list_element_to_front(element.neighbor_history_lens_np, robot_ind)) + element.neighbor_history_extents = move_list_element_to_front(element.neighbor_history_extents, robot_ind) + element.neighbor_types_np = np.array(move_list_element_to_front(element.neighbor_types_np, robot_ind)) + element.neighbor_meta_dicts = move_list_element_to_front(element.neighbor_meta_dicts, robot_ind) + element.extras["robot_ind"] = 0 + return element + + +def remove_parked(element: AgentBatchElement): + is_parked = np.array([meta_dict['is_stationary'] for meta_dict in element.neighbor_meta_dicts]) + get_filtered_list = lambda x: [x[i] for i in range(element.num_neighbors) if not is_parked[i]] + get_filtered_np = lambda x: x[np.logical_not(is_parked)] + + element.neighbor_histories = get_filtered_list(element.neighbor_histories) + element.neighbor_history_extents = get_filtered_list(element.neighbor_history_extents) + element.neighbor_history_lens_np = get_filtered_np(element.neighbor_history_lens_np) + + element.neighbor_futures = get_filtered_list(element.neighbor_futures) + element.neighbor_future_extents = get_filtered_list(element.neighbor_future_extents) + element.neighbor_future_lens_np = get_filtered_np(element.neighbor_future_lens_np) + + element.neighbor_meta_dicts = get_filtered_list(element.neighbor_meta_dicts) + element.neighbor_types_np = get_filtered_np(element.neighbor_types_np) + + element.num_neighbors = len(element.neighbor_types_np) + + return element + + +def augment_with_point_goal(element: AgentBatchElement): + """Add goal information for planning.""" + + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy goal info + goal = np.full((8, ), np.nan, dtype=np.float32) + else: + # Goal is the gt state at the end of the planning horizon. + goal = element.neighbor_futures[robot_ind][-1].astype(np.float32) + element.extras["goal"] = goal + return element + + +def augment_with_goal_lanes(element: AgentBatchElement, goal_to_lane_range: float = 20., max_lateral_dist: float = 4.5, max_heading_delta: float = np.pi/4): + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy lane info + goal_lanes = LanesList([]) + else: + goal_lanes = get_goal_lanes( + element.vec_map, element.extras['goal'], element.agent_from_world_tf, + goal_to_lane_range=goal_to_lane_range, max_lateral_dist=max_lateral_dist, max_heading_delta=max_heading_delta) + + element.extras["lanes_near_goal"] = goal_lanes + return element + + +def augment_with_lanes(element: AgentBatchElement, make_missing_lane_invalid: bool = True): + """Add lane information for planning.""" + + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy lane info + lane_projection_points = None + else: + lane_projection_points = get_lane_projection_points( + element.vec_map, + element.neighbor_histories[robot_ind], element.neighbor_futures[robot_ind], + element.agent_from_world_tf) + if lane_projection_points is None: + lane_projection_points = np.full((element.agent_future_len + 1, 3), np.nan, dtype=np.float32) + if make_missing_lane_invalid: + # We set robot_idx to -1 to indicate that the sample is invalid + element.extras['robot_ind'] = -1 + + element.extras["lane_projection_points"] = lane_projection_points + + return element + + +def get_filter_func(ego_valid=False, pred_near_ego=False, lane_near_ego=False, pred_not_parked=False): + """ + """ + # shortcut no filtering + if not ego_valid and not pred_near_ego and not lane_near_ego and not pred_not_parked: + return None + + def filter_fn(element: AgentBatchElement) -> bool: + if ego_valid and element.extras['robot_ind'] < 0: + return False + + if pred_not_parked: + # # Implementation 1: compute distance for valid history and future + # start_to_end_dist = np.linalg.norm(element.agent_history_np[0, :2] - element.agent_future_np[-1, :2]) + # if start_to_end_dist < 1.: + # return False + + # Implementation 2: use pre-computed metainfo based on entire valid trajectory + if element.agent_meta_dict['is_stationary']: + return False + + if pred_near_ego: + # Only keep if the closest distance betwen ego and predicted agent for future steps is under 10 meters. + assert ego_valid, "Assert: it only make sense to use pred_near_ego if we filter samples with valid ego." + shorter_future_len = min(element.agent_future_len, element.neighbor_future_lens_np[element.extras['robot_ind']]) + ego_futures = element.neighbor_futures[element.extras['robot_ind']][:shorter_future_len, :2] + pred_futures = element.agent_future_np[:shorter_future_len, :2] + dists = np.linalg.norm(ego_futures - pred_futures, axis=-1) + min_ego_pred_dist = np.min(dists) + if min_ego_pred_dist > 10.: + return False + if lane_near_ego: + raise NotImplementedError() + return True + return filter_fn + + diff --git a/diffstack/data/scene_batch_extras.py b/diffstack/data/scene_batch_extras.py new file mode 100644 index 0000000..efa2077 --- /dev/null +++ b/diffstack/data/scene_batch_extras.py @@ -0,0 +1,188 @@ +import torch +import numpy as np +from typing import Dict, Iterable, Union, List, Optional + +from trajdata.data_structures.batch_element import SceneBatchElement +from trajdata.data_structures.agent import AgentType + +from diffstack.utils.utils import move_list_element_to_front +from diffstack.data.trajdata_lanes import get_goal_lanes, get_lane_projection_points, LanesList + + +def role_selector(element: SceneBatchElement, pred_agent_types: List[AgentType] = (AgentType.VEHICLE, )): + # Find ego + agent_names = [agent.name for agent in element.agents] + if "ego" in agent_names: + ego_i = next(i for i, name in enumerate(agent_names) if name == "ego") + else: + ego_i = element.agents.index(element.centered_agent) + + + # # Find pred agent that is closest to ego + # dists = [] + # inds = [] + # for n_i in range(len(element.agent_futures)): + # if n_i == ego_i: + # continue + + # # Filter what agents we want to predict + # if element.agent_types_np[n_i] not in pred_agent_types: + # continue + + # # Filter incomplete future + # if element.agent_future_lens_np[n_i] < element.agent_future_lens_np[ego_i]: + # continue + + # # Filter parked vehicles + # if element.agent_meta_dicts[n_i]['is_stationary']: + # continue + + # # Distance from predicted agent + # dist = np.square(element.agent_futures[ego_i][:, :2] - element.agent_futures[n_i][:, :2]) + # dist = np.min(dist.sum(axis=-1), axis=-1) # sum over states, min over time + # inds.append(n_i) + # dists.append(dist) + + # if dists: + # pred_i = inds[np.argmin(np.array(dists))] # neighbor that gets closest to current node + # else: + # # No neighbors or all futures are incomplete + # ego_i = -1 + # pred_i = -1 + + element.extras['robot_ind'] = ego_i + # element.extras['pred_agent_ind'] = pred_i + return element + + +def make_robot_the_first(element: SceneBatchElement, extras_key: str = "robot_ind"): + """Reorder neighbors such that the first neighbor is the proxy robot agent for planning.""" + ind = element.extras[extras_key] + if ind < 0: + # No robot, do nothing + pass + else: + element.agent_futures = move_list_element_to_front(element.agent_futures, ind) + element.agent_future_lens_np = np.array(move_list_element_to_front(element.agent_future_lens_np, ind)) + element.agent_future_extents = move_list_element_to_front(element.agent_future_extents, ind) + element.agent_histories = move_list_element_to_front(element.agent_histories, ind) + element.agent_history_lens_np = np.array(move_list_element_to_front(element.agent_history_lens_np, ind)) + element.agent_history_extents = move_list_element_to_front(element.agent_history_extents, ind) + element.agent_types_np = np.array(move_list_element_to_front(element.agent_types_np, ind)) + element.agent_meta_dicts = move_list_element_to_front(element.agent_meta_dicts, ind) + element.extras[extras_key] = 0 + return element + + +def remove_parked(element: SceneBatchElement, keep_agent_ind: Optional[int] = None): + is_parked = np.array([meta_dict['is_stationary'] for meta_dict in element.agent_meta_dicts]) + if keep_agent_ind is not None and keep_agent_ind >= 0: + is_parked[keep_agent_ind] = False + get_filtered_list = lambda x: [x[i] for i in range(element.num_agents) if not is_parked[i]] + get_filtered_np = lambda x: x[np.logical_not(is_parked)] + + element.agent_histories = get_filtered_list(element.agent_histories) + element.agent_history_extents = get_filtered_list(element.agent_history_extents) + element.agent_history_lens_np = get_filtered_np(element.agent_history_lens_np) + + element.agent_futures = get_filtered_list(element.agent_futures) + element.agent_future_extents = get_filtered_list(element.agent_future_extents) + element.agent_future_lens_np = get_filtered_np(element.agent_future_lens_np) + + element.agent_meta_dicts = get_filtered_list(element.agent_meta_dicts) + element.agent_types_np = get_filtered_np(element.agent_types_np) + + if element.map_patches is not None: + element.map_patches = get_filtered_list(element.map_patches) + + element.num_agents = len(element.agent_types_np) + return element + + +def augment_with_point_goal(element: SceneBatchElement): + """Add goal information for planning.""" + + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy goal info + goal = np.full((8, ), np.nan, dtype=np.float32) + else: + # Goal is the gt state at the end of the planning horizon. + goal = element.agent_futures[robot_ind][-1].astype(np.float32) + element.extras["goal"] = goal + return element + + + +def augment_with_goal_lanes(element: SceneBatchElement, goal_to_lane_range: float = 20., max_lateral_dist: float = 4.5, max_heading_delta: float = np.pi/4): + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy lane info + goal_lanes = LanesList([]) + else: + goal_lanes = get_goal_lanes( + element.vec_map, element.extras['goal'], element.centered_agent_from_world_tf, + goal_to_lane_range=goal_to_lane_range, max_lateral_dist=max_lateral_dist, max_heading_delta=max_heading_delta) + + element.extras["lanes_near_goal"] = goal_lanes + return element + + +def augment_with_lanes(element: SceneBatchElement, make_missing_lane_invalid: bool = True): + """Add lane information for planning.""" + if element.num_agents==0: + return element + robot_ind = element.extras['robot_ind'] + if robot_ind < 0: + # No robot, create dummy lane info + lane_projection_points = None + else: + lane_projection_points = get_lane_projection_points( + element.vec_map, + element.agent_histories[robot_ind], element.agent_futures[robot_ind], + element.centered_agent_from_world_tf) + if lane_projection_points is None: + max_future_len = element.agent_future_lens_np.max() + lane_projection_points = np.full((max_future_len + 1, 3), np.nan, dtype=np.float32) + if make_missing_lane_invalid: + # We set robot_idx to -1 to indicate that the sample is invalid + element.extras['robot_ind'] = -1 + + element.extras["lane_projection_points"] = lane_projection_points + + return element + + +def get_filter_func(ego_valid=False, pred_near_ego=False, lane_near_ego=False, pred_not_parked=False): + """ + """ + # shortcut no filtering + if not ego_valid and not pred_near_ego and not lane_near_ego and not pred_not_parked: + return None + + def filter_fn(element: SceneBatchElement) -> bool: + robot_ind = element.extras['robot_ind'] + pred_agent_ind = element.extras['pred_agent_ind'] + + if ego_valid and (robot_ind < 0 or pred_agent_ind < 0): + return False + + if pred_not_parked: + # Implementation 2: use pre-computed metainfo based on entire valid trajectory + if element.agent_meta_dicts[pred_agent_ind]['is_stationary']: + return False + + if pred_near_ego: + # Only keep if the closest distance betwen ego and predicted agent for future steps is under 10 meters. + assert ego_valid, "Assert: it only make sense to use pred_near_ego if we filter samples with valid ego." + shorter_future_len = min(element.agent_future_lens_np[pred_agent_ind], element.agent_future_lens_np[robot_ind]) + ego_futures = element.agent_futures[robot_ind][:shorter_future_len, :2] + pred_futures = element.agent_futures[pred_agent_ind][:shorter_future_len, :2] + dists = np.linalg.norm(ego_futures - pred_futures, axis=-1) + min_ego_pred_dist = np.min(dists) + if min_ego_pred_dist > 10.: + return False + if lane_near_ego: + raise NotImplementedError() + return True + return filter_fn diff --git a/diffstack/data/trajdata_datamodules.py b/diffstack/data/trajdata_datamodules.py new file mode 100644 index 0000000..009583f --- /dev/null +++ b/diffstack/data/trajdata_datamodules.py @@ -0,0 +1,229 @@ +import os +import numpy as np +from collections import defaultdict +from torch.utils.data import Dataset +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from diffstack.configs.base import TrainConfig +from diffstack.data import scene_batch_extras, agent_batch_extras + +from trajdata import AgentBatch, AgentType, UnifiedDataset + + +class UnifiedDataModule(pl.LightningDataModule): + def __init__(self, data_config, train_config: TrainConfig): + super(UnifiedDataModule, self).__init__() + self._data_config = data_config + self._train_config = train_config + self.train_dataset = None + self.valid_dataset = None + self.test_dataset = None + self.train_batch_size = self._train_config.training.batch_size + self.val_batch_size = self._train_config.validation.batch_size + try: + self.test_batch_size = self._train_config.test["batch_size"] + self.test_num_workers = self._train_config.validation.num_data_workers + except: + self.test_batch_size = self.val_batch_size + self.test_num_workers = self._train_config.validation.num_data_workers + self._is_setup = False + + def setup(self, stage=None): + if self._is_setup: + return + data_cfg = self._data_config + future_sec = ( + data_cfg.future_num_frames * data_cfg.step_time + ) # python float precision bug + history_sec = ( + data_cfg.history_num_frames * data_cfg.step_time + ) # python float precision bug + neighbor_distance = data_cfg.max_agents_distance + attention_radius = defaultdict( + lambda: 20.0 + ) # Default range is 20m unless otherwise specified. + attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 + attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 + + if self._data_config.centric == "scene": + # Use actual ego as our planning agent, pick the closest other agent to predict + # for agent-centric prediction models (like T++). + # We only consider vehicles for prediction agent for now. + # Scene-centric prediction models can do prediction for all agents. + pre_filter_transforms = ( + # lambda el: scene_batch_extras.remove_parked(el, keep_agent_ind=0), + lambda el: scene_batch_extras.role_selector( + el, pred_agent_types=[AgentType.VEHICLE] + ), + ) + if self._data_config.get("remove_parked", False): + pre_filter_transforms = pre_filter_transforms + ( + lambda el: scene_batch_extras.remove_parked(el, keep_agent_ind=0), + ) + transforms = pre_filter_transforms + ( + lambda el: scene_batch_extras.make_robot_the_first(el), + # lambda el: scene_batch_extras.augment_with_point_goal(el), + # lambda el: scene_batch_extras.augment_with_lanes(el, make_missing_lane_invalid=True), + # lambda el: scene_batch_extras.augment_with_goal_lanes(el), + ) + get_filter_func = scene_batch_extras.get_filter_func + else: + pre_filter_transforms = ( + lambda el: agent_batch_extras.remove_parked(el), + lambda el: agent_batch_extras.robot_selector(el), + ) + transforms = pre_filter_transforms + ( + lambda el: agent_batch_extras.make_robot_the_first(el), + lambda el: agent_batch_extras.augment_with_point_goal(el), + lambda el: agent_batch_extras.augment_with_lanes( + el, make_missing_lane_invalid=True + ), + lambda el: agent_batch_extras.augment_with_goal_lanes(el), + ) + get_filter_func = agent_batch_extras.get_filter_func + try: + cache_dir = data_cfg.get("cache_dir", os.environ["TRAJDATA_CACHE_DIR"]) + except: + cache_dir = "~/.unified_avdata_cache" + kwargs = dict( + centric=data_cfg.centric, + desired_data=[data_cfg.trajdata_source_train], + desired_dt=data_cfg.step_time, + future_sec=(future_sec, future_sec), + history_sec=(history_sec, history_sec), + data_dirs={ + data_cfg.trajdata_source_root: data_cfg.dataset_path, + }, + # only_types=[AgentType.VEHICLE,AgentType.PEDESTRIAN], + agent_interaction_distances=defaultdict(lambda: neighbor_distance), + incl_raster_map=data_cfg.get("incl_raster_map", False), + raster_map_params={ + "px_per_m": int(1 / data_cfg.pixel_size), + "map_size_px": data_cfg.raster_size, + "return_rgb": False, + "offset_frac_xy": data_cfg.raster_center, + "original_format": True, + }, + state_format="x,y,xd,yd,xdd,ydd,s,c", + obs_format="x,y,xd,yd,xdd,ydd,s,c", + incl_vector_map=data_cfg.get("incl_vector_map", False), + verbose=True, + max_agent_num=1 + data_cfg.other_agents_num, + # max_neighbor_num = data_cfg.other_agents_num, + num_workers=min(os.cpu_count(), 64), + # num_workers = 0, + ego_only=self._train_config.ego_only, + transforms=transforms, + rebuild_cache=self._train_config.rebuild_cache, + cache_location=cache_dir, + save_index=False, + ) + if kwargs["incl_vector_map"]: + kwargs["vector_map_params"] = { + "incl_road_lanes": True, + "incl_road_areas": False, + "incl_ped_crosswalks": False, + "incl_ped_walkways": False, + "max_num_lanes": data_cfg.max_num_lanes, + "num_lane_pts": data_cfg.num_lane_pts, + "remove_single_successor": data_cfg.remove_single_successor, + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. + "collate": True, + "calc_lane_graph": data_cfg.get("calc_lane_graph", False), + } + if "waymo" in data_cfg.trajdata_source_root: + kwargs["vector_map_params"]["keep_in_memory"] = False + kwargs["vector_map_params"]["radius"] = 300 + print(kwargs) + self.train_dataset = UnifiedDataset(**kwargs) + + # prepare validation dataset + kwargs["desired_data"] = [data_cfg.trajdata_source_valid] + if data_cfg.trajdata_val_source_root is not None: + kwargs["data_dirs"] = { + data_cfg.trajdata_val_source_root: data_cfg.dataset_path, + } + self.valid_dataset = UnifiedDataset(**kwargs) + + # prepare test dataset if specified + if data_cfg.trajdata_source_test is not None: + kwargs["desired_data"] = [data_cfg.trajdata_source_test] + if data_cfg.trajdata_test_source_root is not None: + kwargs["data_dirs"] = { + data_cfg.trajdata_test_source_root: data_cfg.dataset_path, + } + self.test_dataset = UnifiedDataset(**kwargs) + self._is_setup = True + + def train_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.train_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + + return DataLoader( + dataset=self.train_dataset, + shuffle=True, + batch_size=self.train_batch_size, + num_workers=self._train_config.training.num_data_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True + if self._train_config.training.num_data_workers + else False, + ) + + def val_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.valid_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + + return DataLoader( + dataset=self.valid_dataset, + shuffle=False, + batch_size=self.val_batch_size, + num_workers=self._train_config.validation.num_data_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True + if self._train_config.validation.num_data_workers > 0 + else False, + ) + + def test_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.test_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + return ( + self.val_dataloader() + if self.test_dataset is None + else DataLoader( + dataset=self.test_dataset, + shuffle=False, + batch_size=self.test_batch_size, + num_workers=self.test_num_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True if self.test_num_workers > 0 else False, + ) + ) + + def predict_dataloader(self): + return self.val_dataloader diff --git a/diffstack/data/trajdata_lanes.py b/diffstack/data/trajdata_lanes.py new file mode 100644 index 0000000..01bbdbb --- /dev/null +++ b/diffstack/data/trajdata_lanes.py @@ -0,0 +1,165 @@ +import torch +import numpy as np + +from typing import Dict, Iterable +from time import time +from trajdata.caching.scene_cache import SceneCache + +from trajdata.utils.arr_utils import angle_wrap +from trajdata.utils.map_utils import get_polyline_headings +from trajdata.data_structures.collation import CustomCollateData +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.maps.vec_map import VectorMap, RoadLane, Polyline + +# from trajdata.maps.map_kdtree import LaneCenterKDTree +from trajdata.utils.arr_utils import ( + batch_nd_transform_points_np, + batch_nd_transform_angles_np, + batch_nd_transform_points_angles_np, + angle_wrap, +) +from diffstack.utils.geometry_utils import batch_proj +from diffstack.utils.utils import convert_state_pred2plan, lane_frenet_features_simple + + +# # For nuscenes-based implementation +# from nuscenes.map_expansion.map_api import NuScenesMap, locations as nusc_map_names +# from diffstack.data.manual_preprocess_plan_data import get_lane_reference_points + + +class LanesList(list, CustomCollateData): + @staticmethod + def __collate__(elements: list) -> any: + return LanesList( + [[torch.as_tensor(pts) for pts in lanelist] for lanelist in elements] + ) + + def __to__(self, device, non_blocking=False): + # Always keep on cpu + del device + del non_blocking + return self + + +def get_lane_reference_points_from_polylines( + global_xyh: np.ndarray, + lane_polylines_xyh: Iterable[np.ndarray], + max_num_points: int = 16, + resolution_meters: float = 1.0, +): + # Previosly we did an interpolation of 1m. This was important for cost functions based on N ref points, + # but its irrelevant for cost functions using projection of the expert state onto the lane segment. + # The projection will remain the same regardless of interpolated points. + # The only difference would be in headings, which are not interpolated this way. One solution is to do + # heading interpolation after projection. + + if not lane_polylines_xyh: + return np.zeros((global_xyh.shape[0], 0, 3), dtype=np.float32) + + # Code v5: keep top-n closest lane points, deprioritize lane points with over 45 heading difference + pts_xyh = np.concatenate(lane_polylines_xyh, axis=0) + + # Filter distance, set2set [traj, lanept] + d2mat = np.square(pts_xyh[None, :, :2] - global_xyh[:, None, :2]).sum(-1) + heading_diff_mat = pts_xyh[None, :, 2] - global_xyh[:, None, 2] + heading_diff_mat = np.minimum( + np.abs(angle_wrap(heading_diff_mat)), # same direction + np.abs(angle_wrap(heading_diff_mat + np.pi)), + ) # opposite direction + # Add a large constant if heading differs by over 45 degrees + d2mat += (heading_diff_mat > np.pi / 4).astype(d2mat.dtype) * 1e6 + closest_ind = np.argsort(d2mat, axis=-1)[:, :max_num_points] + pts_xyh = pts_xyh[closest_ind] # (traj, max_num_points, 4) + + return pts_xyh + + +def get_lane_projection_points( + vec_map: VectorMap, + ego_histories: np.ndarray, + ego_futures: np.ndarray, + agent_from_world_tf: np.ndarray, +) -> np.ndarray: + """Points along a lane that are closest to GT future ego states.""" + + if vec_map is None: + return None + # Get present and future for the ego agent. + ego_pres_future = np.concatenate([ego_histories[-1:], ego_futures], axis=0) + ego_xyhv = convert_state_pred2plan(ego_pres_future) + world_from_agent_tf = np.linalg.inv(agent_from_world_tf) + ego_xyh_world = batch_nd_transform_points_angles_np( + ego_xyhv[:, :3], world_from_agent_tf + ) + + # trajlen = ego_xyh_world.shape[0] + # if trajlen == 0: + # lane_points_xyh_world = [] + # else: + ego_xyz_world = np.concatenate( + (ego_xyh_world[:, :2], np.zeros_like(ego_xyh_world[:, :1])), axis=-1 + ) + closest_lanes = vec_map.get_closest_unique_lanes(ego_xyz_world) + lane_polylines_xyh = [lane.center.points[:, (0, 1, 3)] for lane in closest_lanes] + lane_points_xyh_world = get_lane_reference_points_from_polylines( + ego_xyh_world, lane_polylines_xyh, max_num_points=1 + ) + + if lane_points_xyh_world.shape[1] == 0: + # No lanes. + return None + else: + # World to agent coordinates + lane_points_xyh = batch_nd_transform_points_angles_np( + lane_points_xyh_world[..., :3], agent_from_world_tf[None] + ) + + lane_projection_points = lane_frenet_features_simple( + ego_xyhv[..., :3], lane_points_xyh[:, :] + ) + + assert lane_projection_points.shape == (ego_futures.shape[0] + 1, 3) + + return lane_projection_points.astype(np.float32) + + +def get_goal_lanes( + vec_map: VectorMap, + goal_xyvvaahh: np.ndarray, + agent_from_world_tf: np.ndarray, + goal_to_lane_range: float = 20.0, + max_lateral_dist: float = 4.5, + max_heading_delta: float = np.pi / 4, +): + assert goal_xyvvaahh.shape[-1] == 8 # xyvvaahh + + goal_xyhv = convert_state_pred2plan(goal_xyvvaahh) + world_from_agent_tf = np.linalg.inv(agent_from_world_tf) + goal_xyh_world = batch_nd_transform_points_angles_np( + goal_xyhv[:3], world_from_agent_tf + ) + + # Find lanes in range + goal_xyz_world = np.concatenate( + (goal_xyh_world[:2], np.zeros_like(goal_xyh_world[:1])), axis=-1 + ) + near_lanes = vec_map.get_lanes_within(goal_xyz_world, dist=goal_to_lane_range) + near_lanes_xyh_world = [lane.center.points[:, (0, 1, 3)] for lane in near_lanes] + + # Filter + lanes_xyh_world = [] + for lane_xyh in near_lanes_xyh_world: + delta_x, delta_y, dpsi = batch_proj(goal_xyh_world, lane_xyh) + if ( + abs(dpsi[0]) < max_heading_delta + and np.min(np.abs(delta_y)) < max_lateral_dist + ): + lanes_xyh_world.append(lane_xyh) + + # World to agent coordinates + goal_lanes = [ + batch_nd_transform_points_angles_np(lane_xyh, agent_from_world_tf[None]) + for lane_xyh in lanes_xyh_world + ] + + return LanesList(goal_lanes) diff --git a/diffstack/dynamics/__init__.py b/diffstack/dynamics/__init__.py new file mode 100644 index 0000000..b0288a3 --- /dev/null +++ b/diffstack/dynamics/__init__.py @@ -0,0 +1,21 @@ +from typing import Union + +from diffstack.dynamics.single_integrator import SingleIntegrator +from diffstack.dynamics.unicycle import Unicycle +from diffstack.dynamics.bicycle import Bicycle +from diffstack.dynamics.double_integrator import DoubleIntegrator +from diffstack.dynamics.base import Dynamics, DynType +from diffstack.dynamics.unicycle import Unicycle + + +def get_dynamics_model(dyn_type: Union[str, DynType]): + if dyn_type in ["Unicycle", DynType.UNICYCLE]: + return Unicycle + elif dyn_type == ["SingleIntegrator", DynType.SI]: + return SingleIntegrator + elif dyn_type == ["DoubleIntegrator", DynType.DI]: + return DoubleIntegrator + else: + raise NotImplementedError( + "Dynamics model {} is not implemented".format(dyn_type) + ) diff --git a/diffstack/dynamics/base.py b/diffstack/dynamics/base.py new file mode 100644 index 0000000..b43e32e --- /dev/null +++ b/diffstack/dynamics/base.py @@ -0,0 +1,87 @@ +import torch +import numpy as np +import math, copy, time +import abc +from copy import deepcopy + + +class DynType: + """ + Holds environment types - one per environment class. + These act as identifiers for different environments. + """ + + UNICYCLE = 1 + SI = 2 + DI = 3 + BICYCLE = 4 + DDI = 5 + + +class Dynamics(abc.ABC): + @abc.abstractmethod + def __init__(self, dt, name, **kwargs): + self.dt = dt + self._name = name + self.xdim = 4 + self.udim = 2 + + @abc.abstractmethod + def __call__(self, x, u): + return + + @abc.abstractmethod + def step(self, x, u, dt, bound=True): + return + + def name(self): + return self._name + + @abc.abstractmethod + def type(self): + return + + @abc.abstractmethod + def ubound(self, x): + return + + @staticmethod + def state2pos(x): + return + + @staticmethod + def state2yaw(x): + return + + @staticmethod + def get_state(pos,yaw,dt,mask): + return + + def get_input_violation(self,x,u): + lb, ub = self.ubound(x) + return torch.maximum((lb-u).clip(min=0.0), (u-ub).clip(min=0.0)) + + + def forward_dynamics(self,x0: torch.Tensor,u: torch.Tensor, include_step0: bool = False): + """ + Integrate the state forward with initial state x0, action u + Args: + initial_states (Torch.tensor): state tensor of size [B, (A), 4] + u (Torch.tensor): action tensor of size [B, (A), T, 2] + include_step0 (bool): the output trajectory will include the current state if true. + Returns: + state tensor of size [B, (A), T, 4] + """ + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :]) + + if include_step0: + x = torch.stack(x, dim=-2) + else: + x = torch.stack(x[1:], dim=-2) + return x + + + diff --git a/diffstack/dynamics/bicycle.py b/diffstack/dynamics/bicycle.py new file mode 100644 index 0000000..3590f37 --- /dev/null +++ b/diffstack/dynamics/bicycle.py @@ -0,0 +1,153 @@ +import torch +import math + +from diffstack.dynamics.base import Dynamics, DynType + + +def bicycle_model(state, acc, ddh, vehicle_length, dt, max_hdot=math.pi * 2.0, max_s=50.0): + """ + Simple differentiable bicycle model that does not allow reverse + Args: + state (torch.Tensor): a batch of current kinematic state [B, ..., 5] (x, y, yaw, speed, hdot) + acc (torch.Tensor): a batch of acceleration profile [B, ...] (acc) + ddh (torch.Tensor): a batch of heading acceleration profile [B, ...] (heading) + vehicle_length (torch.Tensor): a batch of vehicle length [B, ...] (length) + dt (float): time between steps + max_hdot (float): maximum change of heading (rad/s) + max_s (float): maximum speed (m/s) + + Returns: + New kinematic state (torch.Tensor) + """ + # state: (x, y, h, speed, hdot) + assert state.shape[-1] == 5 + newhdot = (state[..., 4] + ddh * dt).clamp(-max_hdot, max_hdot) + newh = state[..., 2] + dt * state[..., 3].abs() / vehicle_length * newhdot + news = (state[..., 3] + acc * dt).clamp(0.0, max_s) # no reverse + newy = state[..., 1] + news * newh.sin() * dt + newx = state[..., 0] + news * newh.cos() * dt + + newstate = torch.empty_like(state) + newstate[..., 0] = newx + newstate[..., 1] = newy + newstate[..., 2] = newh + newstate[..., 3] = news + newstate[..., 4] = newhdot + + return newstate + + +class Bicycle(Dynamics): + + def __init__( + self, + dt, + acc_bound=(-10, 8), + ddh_bound=(-math.pi * 2.0, math.pi * 2.0), + max_speed=50.0, + max_hdot=math.pi * 2.0 + ): + """ + A simple bicycle dynamics model + Args: + acc_bound (tuple): acceleration bound (m/s^2) + ddh_bound (tuple): angular acceleration bound (rad/s^2) + max_speed (float): maximum speed, must be positive + max_hdot (float): maximum turning speed, must be positive + """ + super(Bicycle, self).__init__(name="bicycle") + self.xdim = 6 + self.udim = 2 + self.dt = dt + assert max_speed >= 0 + assert max_hdot >= 0 + self.acc_bound = acc_bound + self.ddh_bound = ddh_bound + self.max_speed = max_speed + self.max_hdot = max_hdot + + def get_normalized_controls(self, u): + u = torch.sigmoid(u) # normalize to [0, 1] + acc = self.acc_bound[0] + (self.acc_bound[1] - self.acc_bound[0]) * u[..., 0] + ddh = self.ddh_bound[0] + (self.ddh_bound[1] - self.ddh_bound[0]) * u[..., 1] + return acc, ddh + + def get_clipped_controls(self, u): + acc = torch.clip(u[..., 0], self.acc_bound[0], self.acc_bound[1]) + ddh = torch.clip(u[..., 1], self.ddh_bound[0], self.ddh_bound[1]) + return acc, ddh + + def step(self, x, u, normalize=True): + """ + Take a step with the dynamics model + Args: + x (torch.Tensor): current state [B, ..., 6] (x, y, h, speed, dh, veh_length) + u (torch.Tensor): (un-normalized) actions [B, ..., 2] (acc, ddh) + dt (float): time between steps + normalize (bool): whether to normalize the actions + + Returns: + next_x (torch.Tensor): next state after taking the action + """ + assert x.shape[-1] == self.xdim + assert u.shape[:-1] == x.shape[:-1] + assert u.shape[-1] == self.udim + if normalize: + acc, ddh = self.get_normalized_controls(u) + else: + acc, ddh = self.get_clipped_controls(u) + next_x = x.clone() # keep the extent the same + next_x[..., :5] = bicycle_model( + state=x[..., :5], + acc=acc, + ddh=ddh, + vehicle_length=x[..., 5], + dt=self.dt, + max_hdot=self.max_hdot, + max_s=self.max_speed + ) + return next_x + + def calculate_vel(self,pos, yaw, mask): + + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / self.dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / self.dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + + return vel + + def type(self): + return DynType.BICYCLE + + def state2pos(self, x): + return x[..., :2] + + def state2yaw(self, x): + return x[..., 2:3] + + def __call__(self, x, u): + pass + + def ubound(self, x): + pass + + def name(self): + return self._name diff --git a/diffstack/dynamics/double_integrator.py b/diffstack/dynamics/double_integrator.py new file mode 100644 index 0000000..3e13630 --- /dev/null +++ b/diffstack/dynamics/double_integrator.py @@ -0,0 +1,202 @@ +from diffstack.dynamics.base import DynType, Dynamics +from diffstack.utils.math_utils import soft_sat +import torch +import numpy as np + + +class DoubleIntegrator(Dynamics): + def __init__(self, dt, name="DI_model", abound=None, vbound=None): + self._name = name + self._type = DynType.DI + self.xdim = 4 + self.udim = 2 + self.dt = dt + self.cyclic_state = list() + self.vbound = vbound + self.abound = abound + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + if isinstance(x, np.ndarray): + return np.hstack((x[..., 2:], u)) + elif isinstance(x, torch.Tensor): + return torch.cat((x[..., 2:], u), dim=-1) + else: + raise NotImplementedError + + def step(self, x, u, bound=True, return_jacobian=False): + if self.abound is None: + bound = False + + if isinstance(x, np.ndarray): + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + xn = np.hstack( + ( + (x[..., 2:4] + 0.5 * u * self.dt) * self.dt + x[..., 0:2], + x[..., 2:4] + u * self.dt, + ) + ) + if return_jacobian: + raise NotImplementedError + else: + return xn + elif isinstance(x, torch.Tensor): + if bound: + lb, ub = self.ubound(x) + u = torch.clip(u, min=lb, max=ub) + xn = torch.clone(x) + xn[..., 0:2] += (x[..., 2:4] + 0.5 * u * self.dt) * self.dt + xn[..., 2:4] += u * self.dt + if return_jacobian: + jacx = torch.cat( + [ + torch.zeros([4, 2]), + torch.cat([torch.eye(2) * self.dt, torch.zeros([2, 2])], 0), + ], + -1, + ) + torch.eye(4) + jacx = torch.tile(jacx, (*x.shape[:-1], 1, 1)).to(x.device) + + jacu = torch.cat( + [torch.eye(2) * self.dt**2 * 2, torch.eye(2) * self.dt], 0 + ) + jacu = torch.tile(jacu, (*x.shape[:-1], 1, 1)).to(x.device) + return xn, jacx, jacu + else: + return xn + else: + raise NotImplementedError + + def get_x_Gaussian_from_u(self, x, mu_u, var_u): + mu_x, _, jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + var_u_mat = torch.diag_embed(var_u) + var_x = torch.matmul(torch.matmul(jacu, var_u_mat), jacu.transpose(-1, -2)) + return mu_x, var_x + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if self.vbound is None: + if isinstance(x, np.ndarray): + lb = np.ones_like(x[..., 2:]) * self.abound[0] + ub = np.ones_like(x[..., 2:]) * self.abound[1] + + elif isinstance(x, torch.Tensor): + lb = torch.ones_like(x[..., 2:]) * torch.from_numpy( + self.abound[:, 0] + ).to(x.device) + ub = torch.ones_like(x[..., 2:]) * torch.from_numpy( + self.abound[:, 1] + ).to(x.device) + + else: + raise NotImplementedError + else: + if isinstance(x, np.ndarray): + lb = (x[..., 2:] > self.vbound[0]) * self.abound[0] + ub = (x[..., 2:] < self.vbound[1]) * self.abound[1] + + elif isinstance(x, torch.Tensor): + lb = ( + x[..., 2:] > torch.from_numpy(self.vbound[0]).to(x.device) + ) * torch.from_numpy(self.abound[0]).to(x.device) + ub = ( + x[..., 2:] < torch.from_numpy(self.vbound[1]).to(x.device) + ) * torch.from_numpy(self.abound[1]).to(x.device) + else: + raise NotImplementedError + return lb, ub + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2yaw(x): + # return torch.atan2(x[..., 3:], x[..., 2:3]) + return torch.zeros_like(x[..., 0:1]) + + def inverse_dyn(self, x, xp): + return (xp[..., 2:] - x[..., 2:]) / self.dt + + def calculate_vel(self, pos, yaw, mask): + vel = (pos[..., 1:, :] - pos[..., :-1, :]) / self.dt + if isinstance(pos, torch.Tensor): + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r, -1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r), -1) * vel_l + + np.expand_dims(mask_r & (~mask_l), -1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def get_state(self, pos, yaw, dt, mask): + vel = self.calculate_vel(pos, yaw, mask) + if isinstance(vel, np.ndarray): + return np.concatenate((pos, vel), -1) + elif isinstance(vel, torch.Tensor): + return torch.cat((pos, vel), -1) + + def forward_dynamics( + self, + x0: torch.Tensor, + u: torch.Tensor, + include_step0: bool = False, + ): + if include_step0: + raise NotImplementedError + + if isinstance(u, np.ndarray): + u = np.clip(u, self.abound[0], self.abound[1]) + delta_v = np.cumsum(u * self.dt, -2) + vel = x0[..., np.newaxis, 2:] + delta_v + vel = np.clip(vel, self.vbound[0], self.vbound[1]) + delta_xy = np.cumsum(vel * self.dt, -2) + xy = x0[..., np.newaxis, :2] + delta_xy + + traj = np.concatenate((xy, vel), -1) + elif isinstance(u, torch.Tensor): + u = soft_sat(u, self.abound[0], self.abound[1]) + delta_v = torch.cumsum(u * self.dt, -2) + vel = x0[..., 2:].unsqueeze(-2) + delta_v + vel = soft_sat(vel, self.vbound[0], self.vbound[1]) + delta_xy = torch.cumsum(vel * self.dt, -2) + xy = x0[..., :2].unsqueeze(-2) + delta_xy + + traj = torch.cat((xy, vel), -1) + return traj diff --git a/diffstack/dynamics/single_integrator.py b/diffstack/dynamics/single_integrator.py new file mode 100644 index 0000000..ad2b2ea --- /dev/null +++ b/diffstack/dynamics/single_integrator.py @@ -0,0 +1,52 @@ +from diffstack.dynamics.base import DynType, Dynamics +import torch +import numpy as np + + +class SingleIntegrator(Dynamics): + def __init__(self, dt, name, vbound): + self._name = name + self.dt = dt + self._type = DynType.SI + self.xdim = vbound.shape[0] + self.udim = vbound.shape[0] + self.cyclic_state = list() + self.vbound = np.array(vbound) + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + + return u + + def step(self, x, u, bound=True): + assert x.shape[:-1] == u.shape[:, -1] + if bound: + lb, ub = self.ubound(x) + if isinstance(x, np.ndarray): + u = np.clip(u, lb, ub) + elif isinstance(x, torch.Tensor): + u = torch.clip(u, min=lb, max=ub) + + return x + u * self.dt + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + lb = np.ones_like(x) * self.vbound[:, 0] + ub = np.ones_like(x) * self.vbound[:, 1] + return lb, ub + elif isinstance(x, torch.Tensor): + lb = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 0]) + ub = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 1]) + return lb, ub + else: + raise NotImplementedError + + @staticmethod + def state2pos(x): + return x[..., 0:2] diff --git a/diffstack/dynamics/unicycle.py b/diffstack/dynamics/unicycle.py new file mode 100644 index 0000000..9bc5e36 --- /dev/null +++ b/diffstack/dynamics/unicycle.py @@ -0,0 +1,834 @@ +from diffstack.dynamics.base import DynType, Dynamics +from diffstack.utils.math_utils import soft_sat +import torch +import torch.nn as nn +import numpy as np +from copy import deepcopy +from torch.autograd.functional import jacobian +import diffstack.utils.geometry_utils as GeoUtils + + +class Unicycle(Dynamics): + def __init__( + self, + dt, + name=None, + max_steer=0.5, + max_yawvel=8, + acce_bound=[-6, 4], + vbound=[-10, 30], + ): + self.dt = dt + self._name = name + self._type = DynType.UNICYCLE + self.xdim = 4 + self.udim = 2 + self.cyclic_state = [3] + self.acce_bound = acce_bound + self.vbound = vbound + self.max_steer = max_steer + self.max_yawvel = max_yawvel + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + theta = x[..., 3:4] + dxdt = np.hstack( + (np.cos(theta) * x[..., 2:3], np.sin(theta) * x[..., 2:3], u) + ) + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + theta = x[..., 3:4] + dxdt = torch.cat( + (torch.cos(theta) * x[..., 2:3], torch.sin(theta) * x[..., 2:3], u), + dim=-1, + ) + else: + raise NotImplementedError + return dxdt + + def step(self, x, u, bound=True, return_jacobian=False): + assert x.shape[:-1] == u.shape[:-1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + + theta = x[..., 3:4] + cos_theta_p = np.cos(theta) - 0.5 * np.sin(theta) * u[..., 1:] * self.dt + sin_theta_p = np.sin(theta) + 0.5 * np.cos(theta) * u[..., 1:] * self.dt + vel_p = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + + dxdt = np.hstack( + ( + cos_theta_p * vel_p, + sin_theta_p * vel_p, + u, + ) + ) + xp = x + dxdt * self.dt + if return_jacobian: + d_cos_theta_p_d_theta = ( + -np.sin(theta) - 0.5 * np.cos(theta) * u[..., 1:] * self.dt + ) + d_sin_theta_p_d_theta = ( + np.cos(theta) - 0.5 * np.sin(theta) * u[..., 1:] * self.dt + ) + d_vel_p_d_a = 0.5 * self.dt + d_cos_theta_p_d_yaw = -0.5 * np.sin(theta) * self.dt + d_sin_theta_p_d_yaw = 0.5 * np.cos(theta) * self.dt + jacx = np.tile(np.eye(4), (*x.shape[:-1], 1, 1)) + jacx[..., 0, 2:3] = cos_theta_p * self.dt + jacx[..., 0, 3:4] = vel_p * self.dt * d_cos_theta_p_d_theta + jacx[..., 1, 2:3] = sin_theta_p * self.dt + jacx[..., 1, 3:4] = vel_p * self.dt * d_sin_theta_p_d_theta + + jacu = np.zeros((*x.shape[:-1], 4, 2)) + jacu[..., 0, 0:1] = cos_theta_p * self.dt * d_vel_p_d_a + jacu[..., 0, 1:2] = vel_p * self.dt * d_cos_theta_p_d_yaw + jacu[..., 1, 0:1] = sin_theta_p * self.dt * d_vel_p_d_a + jacu[..., 1, 1:2] = vel_p * self.dt * d_sin_theta_p_d_yaw + jacu[..., 2, 0:1] = self.dt + jacu[..., 3, 1:2] = self.dt + + return xp, jacx, jacu + else: + return xp + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + if bound: + lb, ub = self.ubound(x) + # s = (u - lb) / torch.clip(ub - lb, min=1e-3) + # u = lb + (ub - lb) * torch.sigmoid(s) + u = torch.clip(u, lb, ub) + + theta = x[..., 3:4] + cos_theta_p = ( + torch.cos(theta) - 0.5 * torch.sin(theta) * u[..., 1:] * self.dt + ) + sin_theta_p = ( + torch.sin(theta) + 0.5 * torch.cos(theta) * u[..., 1:] * self.dt + ) + vel_p = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + dxdt = torch.cat( + ( + cos_theta_p * vel_p, + sin_theta_p * vel_p, + u, + ), + dim=-1, + ) + xp = x + dxdt * self.dt + if return_jacobian: + d_cos_theta_p_d_theta = ( + -torch.sin(theta) - 0.5 * torch.cos(theta) * u[..., 1:] * self.dt + ) + d_sin_theta_p_d_theta = ( + torch.cos(theta) - 0.5 * torch.sin(theta) * u[..., 1:] * self.dt + ) + d_vel_p_d_a = 0.5 * self.dt + d_cos_theta_p_d_yaw = -0.5 * torch.sin(theta) * self.dt + d_sin_theta_p_d_yaw = 0.5 * torch.cos(theta) * self.dt + eye4 = torch.tile(torch.eye(4, device=x.device), (*x.shape[:-1], 1, 1)) + jacxy = torch.zeros((*x.shape[:-1], 4, 2), device=x.device) + zeros21 = torch.zeros((*x.shape[:-1], 2, 1), device=x.device) + jacv = ( + torch.cat( + [cos_theta_p.unsqueeze(-2), sin_theta_p.unsqueeze(-2), zeros21], + -2, + ) + * self.dt + ) + jactheta = ( + torch.cat( + [ + (vel_p * d_cos_theta_p_d_theta).unsqueeze(-2), + (vel_p * d_sin_theta_p_d_theta).unsqueeze(-2), + zeros21, + ], + -2, + ) + * self.dt + ) + jacx = torch.cat([jacxy, jacv, jactheta], -1) + eye4 + # jacx = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # jacx[...,0,2:3] = cos_theta_p*self.dt + # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # jacx[...,1,2:3] = sin_theta_p*self.dt + # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + jacxy_a = ( + torch.cat( + [cos_theta_p.unsqueeze(-2), sin_theta_p.unsqueeze(-2)], -2 + ) + * self.dt + * d_vel_p_d_a + ) + jacxy_yaw = ( + torch.cat( + [ + (vel_p * d_cos_theta_p_d_yaw).unsqueeze(-2), + (vel_p * d_sin_theta_p_d_yaw).unsqueeze(-2), + ], + -2, + ) + * self.dt + ) + eye2 = torch.tile(torch.eye(2, device=x.device), (*x.shape[:-1], 1, 1)) + jacu = torch.cat( + [torch.cat([jacxy_a, jacxy_yaw], -1), eye2 * self.dt], -2 + ) + # jacu = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # jacu[...,1,1:2] = vel_p*self.dt*d_sin_theta_p_d_yaw + # jacu[...,2,0:1] = self.dt + # jacu[...,3,1:2] = self.dt + return xp, jacx, jacu + else: + return xp + else: + raise NotImplementedError + + def get_x_Gaussian_from_u(self, x, mu_u, var_u): + mu_x, _, jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + var_u_mat = torch.diag_embed(var_u) + var_x = torch.matmul(torch.matmul(jacu, var_u_mat), jacu.transpose(-1, -2)) + return mu_x, var_x + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + v = x[..., 2:3] + vclip = np.clip(np.abs(v), a_min=0.1, a_max=None) + + yawbound = np.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + acce_lb = np.clip( + np.clip(self.vbound[0] - v, None, self.acce_bound[1]), + self.acce_bound[0], + None, + ) + acce_ub = np.clip( + np.clip(self.vbound[1] - v, self.acce_bound[0], None), + None, + self.acce_bound[1], + ) + lb = np.concatenate((acce_lb, -yawbound), -1) + ub = np.concatenate((acce_ub, yawbound), -1) + return lb, ub + elif isinstance(x, torch.Tensor): + v = x[..., 2:3] + vclip = torch.clip(torch.abs(v), min=0.1) + yawbound = torch.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + yawbound = torch.clip(yawbound, min=0.1) + acce_lb = torch.clip( + torch.clip(self.vbound[0] - v, max=self.acce_bound[1]), + min=self.acce_bound[0], + ) + acce_ub = torch.clip( + torch.clip(self.vbound[1] - v, min=self.acce_bound[0]), + max=self.acce_bound[1], + ) + lb = torch.cat((acce_lb, -yawbound), dim=-1) + ub = torch.cat((acce_ub, yawbound), dim=-1) + return lb, ub + + else: + raise NotImplementedError + + def uniform_sample_xp(self, x, num_sample): + if isinstance(x, torch.Tensor): + u_lb, u_ub = self.ubound(x) + u_sample = torch.rand( + *x.shape[:-1], num_sample, self.udim, device=x.device + ) * (u_ub - u_lb).unsqueeze(-2) + u_lb.unsqueeze(-2) + xp = self.step( + x.unsqueeze(-2).repeat_interleave(num_sample, -2), u_sample, bound=False + ) + elif isinstance(x, np.ndarray): + u_lb, u_ub = self.ubound(x) + u_sample = np.random.uniform( + u_lb[..., None, :], + u_ub[..., None, :], + (*x.shape[:-1], num_sample, self.udim), + ) + xp = self.step( + x[..., None, :].repeat(num_sample, -2), u_sample, bound=False + ) + else: + raise NotImplementedError + return xp + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2yaw(x): + return x[..., 3:] + + @staticmethod + def state2vel(x): + return x[..., 2:3] + + @staticmethod + def state2xyvsc(x): + return torch.cat([x[..., :3], torch.sin(x[..., 3:]), torch.cos(x[..., 3:])], -1) + + @staticmethod + def combine_to_state(xy, vel, yaw): + if isinstance(xy, torch.Tensor): + return torch.cat((xy, vel, yaw), -1) + elif isinstance(xy, np.ndarray): + return np.concatenate((xy, vel, yaw), -1) + + def calculate_vel(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + if isinstance(pos, torch.Tensor): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * np.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * np.sin(yaw[..., 1:, :]) + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r, -1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r), -1) * vel_l + + np.expand_dims(mask_r & (~mask_l), -1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def inverse_dyn(self, x, xp, dt=None, mask=None): + if dt is None: + dt = self.dt + dx = torch.cat( + [xp[..., 2:3] - x[..., 2:3], GeoUtils.round_2pi(xp[..., 3:] - x[..., 3:])], + -1, + ) + u = dx / dt + if mask is not None: + u = u * mask[..., None] + return u + + def get_state(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + vel = self.calculate_vel(pos, yaw, mask, dt) + if isinstance(vel, np.ndarray): + return np.concatenate((pos, vel, yaw), -1) + elif isinstance(vel, torch.Tensor): + return torch.cat((pos, vel, yaw), -1) + + @staticmethod + def get_axay(x, u): + yaw = Unicycle.state2yaw(x) + vel = Unicycle.state2vel(x) + acce = u[..., 0:1] + r = u[..., 1:] + stack_fun = torch.stack if isinstance(x, torch.Tensor) else np.stack + sin = torch.sin if isinstance(x, torch.Tensor) else np.sin + cos = torch.cos if isinstance(x, torch.Tensor) else np.cos + return stack_fun( + [ + acce * cos(yaw) - vel * r * sin(yaw), + acce * sin(yaw) + vel * r * cos(yaw), + ], + -1, + ) + + def forward_dynamics( + self, + x0: torch.Tensor, + u: torch.Tensor, + mode="parallel", + bound=True, + ): + """ + Integrate the state forward with initial state x0, action u + Args: + initial_states (Torch.tensor): state tensor of size [B, (A), 4] + actions (Torch.tensor): action tensor of size [B, (A), T, 2] + Returns: + state tensor of size [B, (A), T, 4] + """ + if mode == "chain": + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :], bound=bound) + + return torch.stack(x[1:], dim=-2) + + else: + assert mode in ["parallel", "partial_parallel"] + with torch.no_grad(): + num_steps = u.shape[-2] + b = x0.shape[0] + device = x0.device + + mat = torch.ones(num_steps + 1, num_steps + 1, device=device) + mat = torch.tril(mat) + mat = mat.repeat(b, 1, 1) + + mat2 = torch.ones(num_steps, num_steps + 1, device=device) + mat2_h = torch.tril(mat2, diagonal=1) + mat2_l = torch.tril(mat2, diagonal=-1) + mat2 = torch.logical_xor(mat2_h, mat2_l).float() * 0.5 + mat2 = mat2.repeat(b, 1, 1) + if x0.ndim == 3: + mat = mat.unsqueeze(1) + mat2 = mat2.unsqueeze(1) + + acc = u[..., :1] + yaw = u[..., 1:] + if bound: + acc_clipped = soft_sat(acc, self.acce_bound[0], self.acce_bound[1]) + else: + acc_clipped = acc + + if mode == "parallel": + acc_paded = torch.cat( + (x0[..., -2:-1].unsqueeze(-2), acc_clipped * self.dt), dim=-2 + ) + + v_raw = torch.matmul(mat, acc_paded) + v_clipped = soft_sat(v_raw, self.vbound[0], self.vbound[1]) + else: + v_clipped = [x0[..., 2:3]] + [None] * num_steps + for t in range(num_steps): + vt = v_clipped[t] + acc_clipped_t = soft_sat( + acc_clipped[:, t], self.vbound[0] - vt, self.vbound[1] - vt + ) + v_clipped[t + 1] = vt + acc_clipped_t * self.dt + v_clipped = torch.stack(v_clipped, dim=-2) + + v_avg = torch.matmul(mat2, v_clipped) + + v = v_clipped[..., 1:, :] + if bound: + with torch.no_grad(): + v_earlier = v_clipped[..., :-1, :] + + yawbound = torch.minimum( + self.max_steer * torch.abs(v_earlier), + self.max_yawvel / torch.clip(torch.abs(v_earlier), min=0.1), + ) + yawbound_clipped = torch.clip(yawbound, min=0.1) + + yaw_clipped = soft_sat(yaw, -yawbound_clipped, yawbound_clipped) + else: + yaw_clipped = yaw + yawvel_paded = torch.cat( + (x0[..., -1:].unsqueeze(-2), yaw_clipped * self.dt), dim=-2 + ) + yaw_full = torch.matmul(mat, yawvel_paded) + yaw = yaw_full[..., 1:, :] + + # print('before clip', torch.cat((acc[0], yawvel[0]), dim=-1)) + # print('after clip', torch.cat((acc_clipped[0], yawvel_clipped[0]), dim=-1)) + + yaw_earlier = yaw_full[..., :-1, :] + vx = v_avg * torch.cos(yaw_earlier) + vy = v_avg * torch.sin(yaw_earlier) + v_all = torch.cat((vx, vy), dim=-1) + + # print('initial_states[0, -2:]', initial_states[0, -2:]) + # print('vx[0, :5]', vx[0, :5]) + + v_all_paded = torch.cat( + (x0[..., :2].unsqueeze(-2), v_all * self.dt), dim=-2 + ) + x_and_y = torch.matmul(mat, v_all_paded) + x_and_y = x_and_y[..., 1:, :] + + x_all = torch.cat((x_and_y, v, yaw), dim=-1) + return x_all + + # def propagate_and_linearize(self,x0,u,dt=None): + # if dt is None: + # dt = self.dt + # xp,_,_ = self.forward_dynamics(x0,u,dt,mode="chain") + # xl = torch.cat([x0.unsqueeze(1),xp[:,:-1]],1) + # A,B = jacobian(lambda x,u: self.step(x,u,dt),(xl,u)) + # A = A.diagonal(dim1=0,dim2=3).diagonal(dim1=0,dim2=2).permute(2,3,0,1) + # B = B.diagonal(dim1=0,dim2=3).diagonal(dim1=0,dim2=2).permute(2,3,0,1) + # C = xp - (A@xl.unsqueeze(-1)+B@u.unsqueeze(-1)).squeeze(-1) + # return xp,A,B,C + + +class Unicycle_xyvsc(Dynamics): + def __init__( + self, dt,name = None, max_steer=0.5, max_yawvel=8, acce_bound=[-6, 4], vbound=[-10, 30] + ): + self.dt = dt + self._name = name + self._type = DynType.UNICYCLE + self.xdim = 5 + self.udim = 2 + self.acce_bound = acce_bound + self.vbound = vbound + self.max_steer = max_steer + self.max_yawvel = max_yawvel + + def step(self, x, u, bound=True, return_jacobian=False): + assert x.shape[:-1] == u.shape[:-1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + + s = x[..., 3:4] + c = x[..., 4:5] + c_step = c-0.5*s*u[...,1:]*self.dt + s_step = s+0.5*c*u[...,1:]*self.dt + vel_step = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + yaw = u[..., 1:2] + cp = c*np.cos(yaw*self.dt)-s*np.sin(yaw*self.dt) + sp = s*np.cos(yaw*self.dt)+c*np.sin(yaw*self.dt) + dx = np.hstack( + ( + c_step * vel_step*self.dt, + s_step * vel_step*self.dt, + u[...,]*self.dt, + sp-s, + cp-c, + ) + ) + xp = x + dx + if return_jacobian: + raise NotImplementedError + # d_cos_theta_p_d_theta = -np.sin(theta)-0.5*np.cos(theta)*u[...,1:]*self.dt + # d_sin_theta_p_d_theta = np.cos(theta)-0.5*np.sin(theta)*u[...,1:]*self.dt + # d_vel_p_d_a = 0.5*self.dt + # d_cos_theta_p_d_yaw = -0.5*np.sin(theta)*self.dt + # d_sin_theta_p_d_yaw = 0.5*np.cos(theta)*self.dt + # jacx = np.tile(np.eye(4), (*x.shape[:-1], 1, 1)) + # jacx[...,0,2:3] = cos_theta_p*self.dt + # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # jacx[...,1,2:3] = sin_theta_p*self.dt + # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + # jacu = np.zeros((*x.shape[:-1], 4, 2)) + # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # jacu[...,1,1:2] = vel_p*self.dt*d_sin_theta_p_d_yaw + # jacu[...,2,0:1] = self.dt + # jacu[...,3,1:2] = self.dt + + # return xp, jacx, jacu + else: + return xp + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + if bound: + lb, ub = self.ubound(x) + # s = (u - lb) / torch.clip(ub - lb, min=1e-3) + # u = lb + (ub - lb) * torch.sigmoid(s) + u = torch.clip(u, lb, ub) + + s = x[..., 3:4] + c = x[..., 4:5] + c_step = c-0.5*s*u[...,1:]*self.dt + s_step = s+0.5*c*u[...,1:]*self.dt + vel_step = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + yaw = u[..., 1:2] + cp = c*torch.cos(yaw*self.dt)-s*torch.sin(yaw*self.dt) + sp = s*torch.cos(yaw*self.dt)+c*torch.sin(yaw*self.dt) + dx = torch.cat( + ( + c_step * vel_step*self.dt, + s_step * vel_step*self.dt, + u[...,:1]*self.dt, + sp-s, + cp-c, + ),-1 + ) + xp = x + dx + if return_jacobian: + raise NotImplementedError + # d_cos_theta_p_d_theta = -torch.sin(theta)-0.5*torch.cos(theta)*u[...,1:]*self.dt + # d_sin_theta_p_d_theta = torch.cos(theta)-0.5*torch.sin(theta)*u[...,1:]*self.dt + # d_vel_p_d_a = 0.5*self.dt + # d_cos_theta_p_d_yaw = -0.5*torch.sin(theta)*self.dt + # d_sin_theta_p_d_yaw = 0.5*torch.cos(theta)*self.dt + # eye4 = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # jacxy = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # zeros21 = torch.zeros((*x.shape[:-1], 2, 1),device=x.device) + # jacv = torch.cat([cos_theta_p.unsqueeze(-2),sin_theta_p.unsqueeze(-2),zeros21],-2)*self.dt + # jactheta = torch.cat([(vel_p*d_cos_theta_p_d_theta).unsqueeze(-2),(vel_p*d_sin_theta_p_d_theta).unsqueeze(-2),zeros21],-2)*self.dt + # jacx = torch.cat([jacxy,jacv,jactheta],-1)+eye4 + # # jacx = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # # jacx[...,0,2:3] = cos_theta_p*self.dt + # # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # # jacx[...,1,2:3] = sin_theta_p*self.dt + # # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + + + # jacxy_a = torch.cat([cos_theta_p.unsqueeze(-2),sin_theta_p.unsqueeze(-2)],-2)*self.dt*d_vel_p_d_a + # jacxy_yaw = torch.cat([(vel_p*d_cos_theta_p_d_yaw).unsqueeze(-2),(vel_p*d_sin_theta_p_d_yaw).unsqueeze(-2)],-2)*self.dt + # eye2 = torch.tile(torch.eye(2,device=x.device), (*x.shape[:-1], 1, 1)) + # jacu = torch.cat([torch.cat([jacxy_a,jacxy_yaw],-1),eye2*self.dt],-2) + # # jacu = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # # jacu[...,1,1:2] = vel_p*self.dt*d_sin_theta_p_d_yaw + # # jacu[...,2,0:1] = self.dt + # # jacu[...,3,1:2] = self.dt + return xp, jacx, jacu + else: + return xp + else: + raise NotImplementedError + + # def get_x_Gaussian_from_u(self,x,mu_u,var_u): + # mu_x,_,jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + # var_u_mat = torch.diag_embed(var_u) + # var_x = torch.matmul(torch.matmul(jacu,var_u_mat),jacu.transpose(-1,-2)) + # return mu_x,var_x + + def __call__(self, x, u): + return self.step(x, u) + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + v = x[..., 2:3] + vclip = np.clip(np.abs(v), a_min=0.1, a_max=None) + yawbound = np.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + acce_lb = np.clip( + np.clip(self.vbound[0] - v, None, self.acce_bound[1]), + self.acce_bound[0], + None, + ) + acce_ub = np.clip( + np.clip(self.vbound[1] - v, self.acce_bound[0], None), + None, + self.acce_bound[1], + ) + lb = np.concatenate((acce_lb, -yawbound),-1) + ub = np.concatenate((acce_ub, yawbound),-1) + return lb, ub + elif isinstance(x, torch.Tensor): + v = x[..., 2:3] + vclip = torch.clip(torch.abs(v),min=0.1) + yawbound = torch.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + yawbound = torch.clip(yawbound, min=0.1) + acce_lb = torch.clip( + torch.clip(self.vbound[0] - v, max=self.acce_bound[1]), + min=self.acce_bound[0], + ) + acce_ub = torch.clip( + torch.clip(self.vbound[1] - v, min=self.acce_bound[0]), + max=self.acce_bound[1], + ) + lb = torch.cat((acce_lb, -yawbound), dim=-1) + ub = torch.cat((acce_ub, yawbound), dim=-1) + return lb, ub + + else: + raise NotImplementedError + def uniform_sample_xp(self,x,num_sample): + if isinstance(x,torch.Tensor): + u_lb,u_ub = self.ubound(x) + u_sample = torch.rand(*x.shape[:-1],num_sample,self.udim).to(x.device)*(u_ub-u_lb).unsqueeze(-2)+u_lb.unsqueeze(-2) + xp = self.step(x.unsqueeze(-2).repeat_interleave(num_sample,-2),u_sample,bound=False) + elif isinstance(x,np.ndarray): + u_lb,u_ub = self.ubound(x) + u_sample = np.random.uniform(u_lb[...,None,:],u_ub[...,None,:],(*x.shape[:-1],num_sample,self.udim)) + xp = self.step(x[...,None,:].repeat(num_sample,-2),u_sample,bound=False) + else: + raise NotImplementedError + return xp + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2sc(x): + return x[..., 3:] + + @staticmethod + def state2yaw(x): + arctanfun = GeoUtils.ratan2 if isinstance(x,torch.Tensor) else np.arctan2 + return arctanfun(x[...,3:4],x[...,4:5]) + + @staticmethod + def state2vel(x): + return x[..., 2:3] + + @staticmethod + def combine_to_state(xy,vel,yaw): + if isinstance(xy,torch.Tensor): + return torch.cat((xy,vel,torch.sin(yaw),torch.cos(yaw)),-1) + elif isinstance(xy,np.ndarray): + return np.concatenate((xy,vel,np.sin(yaw),np.cos(yaw)),-1) + + def calculate_vel(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + if isinstance(pos, torch.Tensor): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * np.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * np.sin(yaw[..., 1:, :]) + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r,-1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r),-1) * vel_l + + np.expand_dims(mask_r & (~mask_l),-1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def inverse_dyn(self,x,xp,dt=None): + if dt is None: + dt = self.dt + acce = (xp[2:3]-x[2:3])/dt + arctanfun = GeoUtils.ratan2 if isinstance(x,torch.Tensor) else np.arctan2 + catfun = torch.cat if isinstance(x,torch.Tensor) else np.concatenate + yawrate = (arctanfun(xp[3:4],xp[4:5])-arctanfun(x[3:4],x[4:5]))/dt + return catfun([acce,yawrate],-1) + + + def get_state(self,pos,yaw,mask,dt=None): + if dt is None: + dt = self.dt + vel = self.calculate_vel(pos, yaw, mask,dt) + return self.combine_to_state(pos,vel,yaw) + + def forward_dynamics(self, + x0: torch.Tensor, + u: torch.Tensor, + mode="chain", + bound = True, + ): + + """ + Integrate the state forward with initial state x0, action u + Args: + initial_states (Torch.tensor): state tensor of size [B, (A), 4] + actions (Torch.tensor): action tensor of size [B, (A), T, 2] + Returns: + state tensor of size [B, (A), T, 4] + """ + if mode=="chain": + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :],bound=bound) + + return torch.stack(x[1:], dim=-2) + + else: + raise NotImplementedError + + + + + + +def test(): + model = Unicycle(0.1) + x0 = torch.tensor([[1, 2, 3, 4]]).repeat_interleave(3, 0) + u = torch.tensor([[1, 2]]).repeat_interleave(3, 0) + x, jacx, jacu = model.step(x0, u, return_jacobian=True) + + +if __name__ == "__main__": + test() diff --git a/diffstack/models/CTT.py b/diffstack/models/CTT.py new file mode 100644 index 0000000..d6fc3b2 --- /dev/null +++ b/diffstack/models/CTT.py @@ -0,0 +1,2758 @@ +import enum +from collections import OrderedDict, defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from trajdata.data_structures import AgentType + +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.lane_utils as LaneUtils +import diffstack.utils.model_utils as ModelUtils +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.dynamics.unicycle import Unicycle, Unicycle_xyvsc +from diffstack.models.base_models import MLP +from diffstack.models.RPE_simple import sAuxRPEAttention, sAuxRPECrossAttention +from diffstack.models.TypeTransformer import * +from diffstack.utils.diffusion_utils import zero_module +from diffstack.utils.dist_utils import categorical_psample_wor +from diffstack.utils.geometry_utils import ratan2 +from diffstack.utils.homotopy import HOMOTOPY_THRESHOLD, HomotopyType + + +class FeatureAxes(enum.Enum): + """ + Axes of the features + """ + + B = enum.auto() # batch + A = enum.auto() # agent + T = enum.auto() # time + L = enum.auto() # lanes + F = enum.auto() # features + + +class TFvars(enum.Enum): + """ + Variables of the transformer + """ + + Agent_hist = enum.auto() # Agent history trajectories + Agent_future = enum.auto() # Agent future trajectories + Lane = enum.auto() # Lanes + + +class GNNedges(enum.Enum): + """edges of the GNN""" + + Agenthist2Lane = enum.auto() + Agentfuture2Lane = enum.auto() + Agenthist2Agenthist = enum.auto() + Agentfuture2Agentfuture = enum.auto() + Lane2Lane = enum.auto() + + +class FactorizedAttentionBlock(nn.Module): + def __init__( + self, + n_embd: int, + n_head: int, + PE_mode: str, + use_rpe_net: bool, + attn_attributes: OrderedDict, + var_axes: dict, + attn_pdrop: float = 0, + resid_pdrop: float = 0, + nominal_axes_order: list = None, + MAX_T=50, + ): + """a factorized attention block + + Args: + n_embd (int): embedding dimension + n_head (int): number of heads + PE_mode (str): Positional embedding mode, "RPE" or "PE" + use_rpe_net (bool): For RPE, whether use RPE (True) or ProdPENet (False) + attn_attributes (OrderedDict): recipe for attention blocks + var_axes (dict): variables and their axes + attn_pdrop (float, optional): dropout rate for attention. Defaults to 0. + resid_pdrop (float, optional): dropout rate for residue connection. Defaults to 0. + nominal_axes_order (list, optional): axes' order when arranging tensors. Defaults to None. + MAX_T (int, optional): maximum teporal axis length for PE. Defaults to 50. + + """ + super().__init__() + + self.vars = list(var_axes.keys()) + self.var_axes = dict() + for var in self.vars: + self.var_axes[var] = [ + FeatureAxes[var_axis] for var_axis in var_axes[var].split(",") + ] + if nominal_axes_order is None: + nominal_axes_order = list([var for var in FeatureAxes]) + self.nominal_axes_order = nominal_axes_order + self.attn_attributes = attn_attributes + self.attn = nn.ModuleDict() + self.mlp = nn.ModuleDict() + self.bn_1 = nn.ModuleDict() + self.bn_2 = nn.ModuleDict() + # key of attn_attributes are the two variables and the attention axis, + # value of attn_attributes are the edge dimension, edge function (if defined), ntype, and whether to normalize the embedding + for (var1, var2, axis), attributes in attn_attributes.items(): + edge_dim, edge_func, ntype, normalization = attributes + + if var1 == var2: + # self attention + attn_name = var1.name + "_" + var2.name + "_" + axis.name + if axis in [FeatureAxes.A, FeatureAxes.L]: + # agent/lane axis attention + attn_net = TypeSelfAttention( + ntype=ntype, + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + aux_edge_func=edge_func, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + elif axis == FeatureAxes.T: + # time axis attention + if edge_func is None: + if PE_mode == "RPE": + attn_net = sAuxRPEAttention( + n_embd=n_embd, + num_heads=n_head, + aux_vardim=edge_dim, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxSelfAttention( + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + PE_len=MAX_T, + ) + + else: + if PE_mode == "RPE": + attn_net = sAuxRPECrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxCrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + attn_pdrop, + resid_pdrop, + PE_len=MAX_T, + ) + + else: + raise NotImplementedError + if normalization: + self.bn_1[attn_name + "_" + var1.name] = nn.BatchNorm1d(n_embd) + self.bn_2[attn_name] = nn.BatchNorm1d(n_embd) + else: + # cross attention + attn_name = ( + var1.name + + "_" + + var2.name + + "_" + + axis[0].name + + "->" + + axis[1].name + ) + if axis == (FeatureAxes.A, FeatureAxes.L): + # cross attention between agent and lane + attn_net = TypeCrossAttention( + ntype=ntype, + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + aux_edge_func=edge_func, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + elif var1 == TFvars.Agent_future and var2 == TFvars.Agent_hist: + assert axis == (FeatureAxes.T, FeatureAxes.T) + if PE_mode == "RPE": + attn_net = sAuxRPECrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxCrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + attn_pdrop, + resid_pdrop, + PE_len=MAX_T, + ) + else: + raise NotImplementedError + if normalization: + self.bn_1[attn_name + "_" + var1.name] = nn.BatchNorm1d(n_embd) + self.bn_1[attn_name + "_" + var2.name] = nn.BatchNorm1d(n_embd) + self.bn_2[attn_name] = nn.BatchNorm1d(n_embd) + + self.attn[attn_name] = attn_net + self.mlp[attn_name] = tfmlp(n_embd, 4 * n_embd, resid_pdrop) + + def generate_agent_attn_mask(self, agent_mask, ntype): + # return 2 attention masks for self, and neightbors, respectively, or only return neighbor mask + # agent_mask: [B, N] + assert ntype in [1, 2] + B, N = agent_mask.shape + cross_mask = agent_mask[:, :, None] * agent_mask[:, None] # [B,N,N] + # mask to attend to self + I = torch.eye(N).to(agent_mask.device)[None] + notI = torch.logical_not(I) + + self_mask = cross_mask * I + # mask to attend to others + neighbor_mask = cross_mask * notI + if ntype == 1: + return neighbor_mask + else: + return self_mask, neighbor_mask + + def get_cross_mask( + self, var1, var2, axis, x1, x2, var_masks, cross_masks, ignore_var1_mask=True + ): + """get the attention mask for cross attention""" + if (var1, var2, axis) in cross_masks: + mask = cross_masks[(var1, var2, axis)] + else: + a1 = [a.name for a in self.var_axes[var1] if a != FeatureAxes.F] + a2 = [a.name for a in self.var_axes[var2] if a != FeatureAxes.F] + ag = [ + a.name + for a in self.nominal_axes_order + if a != FeatureAxes.F and ((a.name in a1) or (a.name in a2)) + ] + + # swap the attention axes to the last two + ag = [a for a in ag if a != axis[0].name and a != axis[1].name] + [ + axis[0].name, + axis[1].name, + ] + if axis[0] == axis[1]: + # deal with repeated axes + assert axis[1].name.swapcase() not in a1 + assert axis[0].name.swapcase() not in a2 + a2[a2.index(axis[1].name)] = axis[1].name.swapcase() + ag[-1] = ag[-1].swapcase() + cmd = "".join(a1) + "," + "".join(a2) + "->" + "".join(ag) + if var1 in var_masks and not ignore_var1_mask: + mask1 = var_masks[var1] + else: + mask1 = torch.ones_like(x1[..., 0]) + if var2 in var_masks: + mask2 = var_masks[var2] + else: + mask2 = torch.ones_like(x2[..., 0]) + mask = torch.einsum(cmd, mask1, mask2) + mask = mask.reshape(-1, *mask.shape[-2:]) + return mask + + def forward( + self, + vars, + aux_xs=None, + var_masks=dict(), + cross_masks=dict(), + frame_indices=dict(), + edges=dict(), + ): + if TFvars.Agent_hist in vars: + B, N, Th, D = vars[TFvars.Agent_hist].shape + if TFvars.Lane in vars: + B, L = vars[TFvars.Lane].shape[:2] + if TFvars.Agent_future in vars: + Tf = vars[TFvars.Agent_future].shape[2] + if TFvars.Agent_hist in vars: + assert N == vars[TFvars.Agent_future].shape[1] + + for (var1, var2, axis), attributes in self.attn_attributes.items(): + edge = edges[(var1, var2, axis)] if (var1, var2, axis) in edges else None + edge_dim, edge_func, ntype, normalization = attributes + if var1 == var2: + attn_name = var1.name + "_" + var2.name + "_" + axis.name + mlp = self.mlp[attn_name] + # self attention + + attn_axis_idx = self.var_axes[var1].index(axis) + x = vars[var1] + aux_x = aux_xs[var1] if var1 in aux_xs else None + if (var1, var2, axis) in cross_masks: + mask = cross_masks[(var1, var2, axis)] + elif var1 in var_masks: + # provide the feature mask instead of attention mask + mask = var_masks[var1] + mask = mask.transpose(attn_axis_idx, -1) + mask = mask.reshape(-1, mask.size(-1)) + + if axis in [FeatureAxes.A, FeatureAxes.L]: + mask = self.generate_agent_attn_mask(mask, ntype) + else: + mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + else: + mask = None + if attn_axis_idx != len(self.var_axes[var1]) - 2: + # permute the attention axis to second last + x = x.transpose(attn_axis_idx, -2) + aux_x = ( + aux_x.transpose(attn_axis_idx, -2) + if aux_x is not None + else None + ) + orig_shape = x.shape + if axis in [FeatureAxes.A, FeatureAxes.L]: + x = x.reshape(-1, orig_shape[-2], orig_shape[-1]) + aux_x = ( + aux_x.reshape(-1, aux_x.shape[-2], aux_x.shape[-1]) + if aux_x is not None + else None + ) + + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name](xn, aux_x, mask, edge=edge) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp(self.attn[attn_name](x, aux_x, mask, edge=edge)) + elif axis == FeatureAxes.T: + T = x.shape[-2] + x = x.reshape(-1, T, D) + aux_x = ( + aux_x.reshape(-1, T, aux_x.shape[-1]) + if aux_x is not None + else None + ) + frame_index = frame_indices[var1] + frame_index = frame_index.reshape(-1, T) + if isinstance(self.attn[attn_name], sAuxRPEAttention) or isinstance( + self.attn[attn_name], AuxSelfAttention + ): + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name]( + xn, aux_x, mask, frame_indices=frame_index, edge=edge + ) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp( + self.attn[attn_name]( + x, aux_x, mask, frame_indices=frame_index, edge=edge + ) + ) + + elif isinstance( + self.attn[attn_name], sAuxRPECrossAttention + ) or isinstance(self.attn[attn_name], AuxCrossAttention): + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name]( + xn, + xn, + mask, + aux_x, + frame_index, + aux_x, + frame_index, + edge=edge, + ) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp( + self.attn[attn_name]( + x, + x, + mask, + aux_x, + aux_x, + frame_indices1=frame_index, + frame_indices2=frame_index, + edge=edge, + ) + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + x = x.reshape(*orig_shape) + if attn_axis_idx != len(self.var_axes[var1]) - 2: + # permute the attention axis back + x = x.transpose(attn_axis_idx, -2) + vars[var1] = x + else: + # cross attention + attn_name = ( + var1.name + + "_" + + var2.name + + "_" + + axis[0].name + + "->" + + axis[1].name + ) + mlp = self.mlp[attn_name] + x1 = vars[var1] + x2 = vars[var2] + aux_x1 = aux_xs[var1] if var1 in aux_xs else None + aux_x2 = aux_xs[var2] if var2 in aux_xs else None + mask = self.get_cross_mask( + var1, var2, axis, x1, x2, var_masks, cross_masks + ) + + if ( + var1 in [TFvars.Agent_hist, TFvars.Agent_future] + and var2 == TFvars.Lane + ): + # cross attention between agent and lane + assert x1.ndim == 4 # B,N,T,D + assert x2.ndim == 3 # B,L,D + T = x1.shape[2] + x1 = x1.transpose(1, 2).reshape(-1, N, D) # BT,N,D + x2 = x2.repeat_interleave(T, 0) # BT,L,D + aux_x1 = aux_x1.transpose(1, 2).reshape(B * T, N, -1) + aux_x2 = aux_x2.repeat_interleave(T, 0) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, xn2, mask, aux_x1, aux_x2, edge=edge + ) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x1 = x1 + mlp( + self.attn[attn_name]( + x1, x2, mask, aux_x1, aux_x2, edge=edge + ) + ) + x1 = x1.reshape(B, T, N, D).transpose(1, 2) # B,N,T,D + elif var1 == TFvars.Lane and var2 in [ + TFvars.Agent_hist, + TFvars.Agent_future, + ]: + # cross attention between agent and lane + assert x2.ndim == 4 # B,N,T,D + assert x1.ndim == 3 # B,L,D + T = x2.shape[2] + L = x1.shape[1] + x2 = x2.transpose(1, 2).reshape(-1, N, D) # BT,N,D + x1 = x1.repeat_interleave(T, 0) # BT,L,D + aux_x2 = aux_x2.transpose(1, 2).reshape(-1, N, D) + aux_x1 = aux_x1.repeat_interleave(T, 0) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, xn2, mask, aux_x1, aux_x2, edge=edge + ) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x1 = x1 + mlp( + self.attn[attn_name]( + x1, x2, mask, aux_x1, aux_x2, edge=edge + ) + ) + x1 = x1.reshape(B, T, L, D).max(1)[0] # B,L,D + elif var1 == TFvars.Agent_future and var2 == TFvars.Agent_hist: + assert axis == (FeatureAxes.T, FeatureAxes.T) + x2 = vars[TFvars.Agent_hist].reshape(B * N, Th, D) + x1 = vars[TFvars.Agent_future].reshape(B * N, Tf, D) + aux_x2 = aux_xs[TFvars.Agent_hist].reshape(B * N, Th, -1) + aux_x1 = aux_xs[TFvars.Agent_future].reshape(B * N, Tf, -1) + frame_indices1 = frame_indices[TFvars.Agent_future].reshape( + B * N, Tf + ) + frame_indices2 = frame_indices[TFvars.Agent_hist].reshape(B * N, Th) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, + xn2, + mask, + aux_x1, + aux_x2, + frame_indices1, + frame_indices2, + edge=edge, + ).reshape(B, N, Tf, D) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x1 = mlp( + self.attn[attn_name]( + x1, + x2, + mask, + aux_x1, + aux_x2, + frame_indices1, + frame_indices2, + edge=edge, + ) + ).reshape(B, N, Tf, D) + + else: + raise NotImplementedError + vars[var1] = x1 + return vars + + +class FactorizedGNN(nn.Module): + def __init__( + self, + var_axes: dict, + edge_var: dict, + GNN_attributes: OrderedDict, + node_n_embd: int, + edge_n_embd: int, + nominal_axes_order: list = None, + ): + """a factorized GNN + + Args: + var_axes (dict): variables and their axes + edge_var (dict): edges -> the variables it connects + GNN_attributes (OrderedDict): recipe for GNN + node_n_embd (int): embedding dimension for node variables + edge_n_embd (int): embedding dimension for edge variables + nominal_axes_order (list, optional): the order of axes when arranging tensors. Defaults to None. + """ + super().__init__() + self.vars = list(var_axes.keys()) + self.var_axes = dict() + for var in self.vars: + self.var_axes[var] = [ + FeatureAxes[var_axis] for var_axis in var_axes[var].split(",") + ] + self.edge_var = edge_var + + if nominal_axes_order is None: + nominal_axes_order = list([var for var in FeatureAxes]) + self.nominal_axes_order = nominal_axes_order + self.node_n_embd = node_n_embd + self.edge_n_embd = edge_n_embd + + self.GNN_attributes = GNN_attributes + self.GNN_nets = nn.ModuleDict() + self.node_attn = nn.ParameterDict() + self.pe_net = nn.ParameterDict() + self.batch_norm = nn.ModuleDict() + + for (edge, gtype, var), attributes in self.GNN_attributes.items(): + hidden_dim, activation, pooling_method, Nmax = attributes + + if gtype == "edge": + net_name = f"edge_{edge.name}" + var1, var2 = self.edge_var[edge] + nd1, nd2 = self.node_n_embd[var1], self.node_n_embd[var2] + ed = self.edge_n_embd[edge] + # message udpate of edge + + self.GNN_nets[net_name] = MLP( + input_dim=nd1 + nd2 + ed, + output_dim=ed, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=activation if activation is not None else nn.ReLU, + dropouts=None, + normalization=False, + ) + elif gtype == "node": + # message update of node + net_name = f"node_{edge.name}_{var.name}" + ed = self.edge_n_embd[edge] + nd = self.node_n_embd[var] + # message udpate of edge + + self.GNN_nets[net_name] = MLP( + input_dim=nd + ed, + output_dim=nd, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=activation if activation is not None else nn.ReLU, + dropouts=None, + normalization=False, + ) + + if pooling_method == "attn": + self.node_attn[net_name] = CrossAttention(nd, n_head=4) + self.pe_net[net_name] = ( + nn.Parameter(torch.randn(Nmax, nd)) + if Nmax is not None + else None + ) + + def pooling(self, node_feat, node_feat_new, mask, net_name, pooling_method): + """pooling the node features when aggregating the messages for node varaibles""" + # node_feat: [B,N,node_dim] + # node_feat_new: [B,N,N,node_dim] + if pooling_method == "max": + node_feat_new.masked_fill_( + mask[..., None] == 0, node_feat_new.min().detach() - 1 + ) + node_feat_new = node_feat_new.max(2)[0] + elif pooling_method == "mean": + node_feat_new = (node_feat_new * mask[..., None]).sum(2) / mask.sum(2)[ + ..., None + ].clip(min=1e-3) + elif pooling_method == "attn": + N, D = node_feat_new.shape[-2:] + if self.pe_net[net_name] is not None: + node_feat_new = node_feat_new + self.pe_net[net_name][None, None, :N] + node_feat_new = self.node_attn[net_name]( + node_feat.reshape(-1, 1, D), + node_feat_new.reshape(-1, N, D), + mask.reshape(-1, 1, N), + ).reshape(*node_feat.shape) + + return node_feat_new + + def get_cross_mask( + self, var1, var2, edge, x1, x2, var_masks, cross_masks, axis=None + ): + """get the mask for message passing""" + if edge in cross_masks: + mask = cross_masks[edge] + else: + a1 = [a.name for a in self.var_axes[var1] if a != FeatureAxes.F] + a2 = [a.name for a in self.var_axes[var2] if a != FeatureAxes.F] + ag = [ + a.name + for a in self.nominal_axes_order + if a != FeatureAxes.F and ((a.name in a1) or (a.name in a2)) + ] + if var1 == var2: + # self mask + # need to provide the axis on which the mask is applied + assert axis is not None + assert axis.name.swapcase() not in a1 + assert axis.name.swapcase() not in a2 + a2[a2.index(axis.name)] = axis.name.swapcase() + ag.insert(ag.index(axis.name) + 1, axis.name.swapcase()) + + cmd = "".join(a1) + "," + "".join(a2) + "->" + "".join(ag) + if var1 in var_masks: + mask1 = var_masks[var1] + else: + mask1 = torch.ones_like(x1[..., 0]) + if var2 in var_masks: + mask2 = var_masks[var2] + else: + mask2 = torch.ones_like(x2[..., 0]) + mask = torch.einsum(cmd, mask1, mask2) + return mask + + def forward(self, vars, var_masks, cross_masks): + nodes = {var: x for var, x in vars.items() if var in TFvars} + edges = {edge: x for edge, x in vars.items() if edge in GNNedges} + for (edge, gtype, var), attributes in self.GNN_attributes.items(): + _, _, pooling_method, _ = attributes + if gtype == "edge": + net_name = f"edge_{edge.name}" + var1, var2 = self.edge_var[edge] + nx1, nx2 = nodes[var1], nodes[var2] + ex = edges[edge] + if edge in [GNNedges.Agenthist2Lane, GNNedges.Agentfuture2Lane]: + B, N, T = nx1.shape[:3] + M = nx2.size(1) + aggr_feature = torch.cat( + [ + nx1.unsqueeze(3).expand(B, N, T, M, -1), + nx2.view(B, 1, 1, M, -1).expand(B, N, T, M, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + elif edge in [ + GNNedges.Agenthist2Agenthist, + GNNedges.Agentfuture2Agentfuture, + ]: + B, N, T = nx1.shape[:3] + aggr_feature = torch.cat( + [ + nx1.unsqueeze(2).expand(B, N, N, T, -1), + nx1.unsqueeze(1).expand(B, N, N, T, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + elif edge == GNNedges.Lane2Lane: + B, M = nx1.shape[:2] + aggr_feature = torch.cat( + [ + nx1.unsqueeze(2).expand(B, M, M, -1), + nx1.unsqueeze(3).expand(B, M, M, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + else: + raise NotImplementedError + elif gtype == "node": + net_name = f"node_{edge.name}_{var.name}" + ex = edges[edge] + var1, var2 = self.edge_var[edge] + nx1, nx2 = nodes[var1], nodes[var2] + nx = nodes[var] + + if edge in [GNNedges.Agenthist2Lane, GNNedges.Agentfuture2Lane]: + mask = self.get_cross_mask( + var1, var2, edge, nx1, nx2, var_masks, cross_masks + ) + if var in [TFvars.Agent_future, TFvars.Agent_hist]: + B, N, T = nx.shape[:3] + aggr_feature = torch.cat( + [nx.unsqueeze(3).expand(B, N, T, M, -1), ex], dim=-1 + ) + + new_nx = self.GNN_nets[net_name](aggr_feature) + + nodes[var] = nodes[var] + self.pooling( + nx.reshape(B, N * T, -1), + new_nx.view(B, N * T, M, -1), + mask.view(B, N * T, M), + net_name, + pooling_method=pooling_method, + ).view(B, N, T, -1) + elif var == TFvars.Lane: + B, M = nx.shape[:2] + aggr_feature = torch.cat( + [nx[:, None, None].expand(B, N, T, M, -1), ex], dim=-1 + ) + + new_nx = self.GNN_nets[net_name](aggr_feature) # [B,N,T,M,D] + nodes[var] = nodes[var] + self.pooling( + nx, + new_nx.view(B, N * T, M, -1).transpose(1, 2), + mask.view(B, N * T, M).transpose(1, 2), + net_name, + pooling_method=pooling_method, + ) + + elif edge in [ + GNNedges.Agenthist2Agenthist, + GNNedges.Agentfuture2Agentfuture, + ]: + mask = self.get_cross_mask( + var1, + var2, + edge, + nx1, + nx2, + var_masks, + cross_masks, + axis=FeatureAxes.A, + ) + B, N, T = nx.shape[:3] + aggr_feature = torch.cat( + [nx.unsqueeze(2).expand(B, N, N, T, -1), ex], dim=-1 + ) + new_nx = self.GNN_nets[net_name](aggr_feature) # [B,N,N,T,D] + nodes[var] = nodes[var] + self.pooling( + nx.transpose(1, 2).reshape(B * T, N, -1), + new_nx.permute(0, 3, 1, 2, 4).reshape(B * T, N, N, -1), + mask.permute(0, 3, 1, 2).reshape(B * T, N, N), + net_name, + pooling_method=pooling_method, + ).view(B, T, N, -1).transpose(1, 2) + + elif edge == GNNedges.Lane2Lane: + raise NotImplementedError + + return {**nodes, **edges} + + +class tfmlp(nn.Module): + def __init__(self, n_embd, ff_dim, dropout): + """An MLP with GELU activation and dropout for transformer residual connection + + Args: + n_embd (int): embedding dimension + ff_dim (int): feed forward dimension + dropout (float): dropout rate + """ + super().__init__() + self.c_fc = nn.Linear(n_embd, ff_dim) + self.c_proj = zero_module(nn.Linear(ff_dim, n_embd)) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.dropout(self.c_proj(self.act(self.c_fc(x)))) + + +class CTTBlock(nn.Module): + def __init__(self, TF_kwargs, GNN_kwargs): + """a CTT block with a transformer block and a GNN block""" + super().__init__() + self.TFblock = ( + FactorizedAttentionBlock(**TF_kwargs) if TF_kwargs is not None else None + ) + self.GNNblock = FactorizedGNN(**GNN_kwargs) if GNN_kwargs is not None else None + + def forward( + self, vars, aux_xs, var_masks, cross_masks=None, frame_indices=None, edges={} + ): + if self.TFblock is not None: + vars = self.TFblock( + vars, aux_xs, var_masks, cross_masks, frame_indices, edges + ) + if self.GNNblock is not None: + vars = self.GNNblock(vars, var_masks, cross_masks) + return vars + + +class CTTEncoder(nn.Module): + def __init__(self, nblock, TF_kwargs, GNN_kwargs): + super().__init__() + self.blocks = nn.ModuleList( + [CTTBlock(TF_kwargs, GNN_kwargs) for _ in range(nblock)] + ) + + def forward(self, vars, enc_kwargs): + for block in self.blocks: + vars = block(vars, **enc_kwargs) + return vars + + +class CTT(nn.Module): + def __init__( + self, + n_embd: int, + embed_funcs: dict, + enc_nblock: int, + dec_nblock: int, + enc_transformer_kwargs: dict, + enc_GNN_kwargs: dict, + dec_transformer_kwargs: dict, + dec_GNN_kwargs: dict, + enc_output_params: dict, + dec_output_params: dict, + hist_lane_relation, + fut_lane_relation, + max_joint_cardinality: int, + classify_a2l_4all_lanes: bool = False, + edge_func=dict(), + ): + """Categorical Traffic Transformer + + Args: + n_embd (int): embedding dimensions + embed_funcs (dict): embedding functions for variables + enc_nblock (int): number of blocks in encoder + dec_nblock (int): number of blocks in decoder + enc_transformer_kwargs (dict): recipe for transformer in CTT encoder + enc_GNN_kwargs (dict): recipe for GNN in CTT encoder + dec_transformer_kwargs (dict): recipe for transformer in CTT decoder + dec_GNN_kwargs (dict): recipe for GNN in CTT decoder + enc_output_params (dict): parameters for mode prediction of CTT encoder + dec_output_params (dict): parameters for trajectory prediction of CTT decoder + hist_lane_relation: class of lane relation between agent history and lane segments + fut_lane_relation: class of lane relation between agent future and lane segments + max_joint_cardinality (int): maximum number of cardinality for each factor during importance sampling for joint scene mode + classify_a2l_4all_lanes (bool, optional): whether to predict lane mode for every agent-lane pair. Defaults to False. + edge_func (dict, optional): edge function between node variables. Defaults to dict(). + """ + super().__init__() + # build transformer encoder and decoder + self.embed_funcs = nn.ModuleDict({k.name: v for k, v in embed_funcs.items()}) + self.edge_var = {} + if enc_GNN_kwargs is not None: + self.edge_var.update(enc_GNN_kwargs["edge_var"]) + if dec_GNN_kwargs is not None: + self.edge_var.update(dec_GNN_kwargs["edge_var"]) + self.encoder = CTTEncoder(enc_nblock, enc_transformer_kwargs, enc_GNN_kwargs) + self.decoder = CTTEncoder(dec_nblock, dec_transformer_kwargs, dec_GNN_kwargs) + self.n_embd = n_embd + self.max_joint_cardinality = max_joint_cardinality + self.hist_lane_relation = hist_lane_relation + self.fut_lane_relation = fut_lane_relation + self.classify_a2l_4all_lanes = classify_a2l_4all_lanes + self.build_enc_pred_net(enc_output_params) + self.build_dec_output_net(dec_output_params) + self.use_hist_mode = False + self.enc_output_params = enc_output_params + self.dec_output_params = dec_output_params + self.enc_vars = ( + [var for var in enc_transformer_kwargs["var_axes"].keys()] + if enc_transformer_kwargs is not None + else [] + ) + self.enc_edges = ( + [edge for edge in enc_GNN_kwargs["edge_var"].keys()] + if enc_GNN_kwargs is not None + else [] + ) + self.dec_vars = ( + [var for var in dec_transformer_kwargs["var_axes"].keys()] + if dec_transformer_kwargs is not None + else [] + ) + self.dec_edges = ( + [edge for edge in dec_GNN_kwargs["edge_var"].keys()] + if dec_GNN_kwargs is not None + else [] + ) + self.edge_func = edge_func + self.dec_T = self.Tf + + def build_enc_pred_net(self, enc_output_params, use_hist_mode=False): + # Nlr = len(self.fut_lane_relation) + # Nhm = len(HomotopyType) + self.Th = enc_output_params["Th"] + self.pooling_T = enc_output_params["pooling_T"] + self.null_lane_mode = enc_output_params["null_lane_mode"] + mode_embed_dim = enc_output_params.get("mode_embed_dim", 64) + self.lm_embed = nn.Embedding(len(self.fut_lane_relation), mode_embed_dim) + self.homotopy_embed = nn.Embedding(len(HomotopyType), mode_embed_dim) + hidden_dim = ( + enc_output_params["lane_mode"]["hidden_dim"] + if "hidden_dim" in enc_output_params["lane_mode"] + else [self.n_embd * 4, self.n_embd * 2] + ) + # marginal probability of lane relation + self.lane_mode_net = MLP( + input_dim=self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + output_dim=len(self.fut_lane_relation) + - ( + 1 - int(self.classify_a2l_4all_lanes) + ), # If loss over lanes, no NOTON relation + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + if self.pooling_T == "attn": + n_head = enc_output_params["lane_mode"]["n_head"] + if enc_output_params["PE_mode"] == "RPE": + self.lane_mode_attn_T = sAuxRPECrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + + self.homotopy_attn_T = sAuxRPECrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + + self.agent_hist_attn_T = sAuxRPECrossAttention( + self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + elif enc_output_params["PE_mode"] == "PE": + self.lane_mode_attn_T = AuxCrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + self.homotopy_attn_T = AuxCrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + self.agent_hist_attn_T = AuxCrossAttention( + self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + + hidden_dim = ( + enc_output_params["homotopy"]["hidden_dim"] + if "hidden_dim" in enc_output_params["homotopy"] + else [self.n_embd * 4, self.n_embd * 2] + ) + + # marginal probability of homotopy + self.homotopy_net = MLP( + input_dim=self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + output_dim=len(HomotopyType), + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + GNN_kwargs = enc_output_params["joint_mode"]["GNN_kwargs"] + jm_GNN_nblock = enc_output_params["joint_mode"].get("jm_GNN_nblock", 2) + + self.JM_GNN = nn.ModuleList( + [FactorizedGNN(**GNN_kwargs) for i in range(jm_GNN_nblock)] + ) + hidden_dim = enc_output_params["joint_mode"].get("hidden_dim", [256, 128]) + self.JM_lane_mode_factor = MLP( + input_dim=self.n_embd + mode_embed_dim * 2 + if use_hist_mode + else self.n_embd + mode_embed_dim, + output_dim=1, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + self.JM_homotopy_factor = MLP( + input_dim=self.n_embd + mode_embed_dim * 2 + if use_hist_mode + else self.n_embd + mode_embed_dim, + output_dim=1, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + self.num_joint_samples = enc_output_params["joint_mode"].get( + "num_joint_samples", 30 + ) + self.num_joint_factors = enc_output_params["joint_mode"].get( + "num_joint_factors", 6 + ) + + def build_dec_output_net(self, dec_output_params): + arch = dec_output_params["arch"] + self.Tf = dec_output_params["Tf"] + self.arch = arch + dyn = dec_output_params.get("dyn", None) + self.dyn = dyn + self.dt = dec_output_params.get("dt", 0.1) + traj_dim = dec_output_params["traj_dim"] + self.decode_num_modes = dec_output_params.get("decode_num_modes", 1) + self.AR_step_size = dec_output_params.get("AR_step_size", 1) + self.AR_update_mode = dec_output_params.get("AR_update_mode", "step") + self.LR_sample_hack = dec_output_params.get("LR_sample_hack", False) + self.dec_rounds = dec_output_params.get("dec_rounds", 3) + assert self.Tf % self.AR_step_size == 0 + if arch == "lstm": + self.output_rnn = nn.ModuleDict() + num_layers = dec_output_params.get("num_layers", 1) + hidden_size = dec_output_params.get("lstm_hidden_size", self.n_embd) + for k, v in self.dyn.items(): + if v is not None: + proj_size = v.udim + else: + proj_size = traj_dim + self.output_rnn[k.name] = nn.LSTM( + self.n_embd, + hidden_size, + batch_first=True, + num_layers=num_layers, + proj_size=proj_size, + ) + elif arch == "mlp": + self.output_mlp = nn.ModuleDict() + hidden_dim = dec_output_params.get( + "mlp_hidden_dims", [self.n_embd * 2, self.n_embd * 4] + ) + for k, v in self.dyn.items(): + self.output_mlp[k.name] = MLP( + input_dim=self.n_embd, + output_dim=traj_dim if v is None else v.udim, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + def embed_raw(self, raw_vars, aux_xs, aux_edges={}): + vars = dict() + for kname, v in self.embed_funcs.items(): + if hasattr(TFvars, kname): + # embed a TFvar + k = getattr(TFvars, kname) + if k in raw_vars: + vars[k] = self.embed_funcs[k.name](raw_vars[k]) + + elif hasattr(GNNedges, kname): + k = getattr(GNNedges, kname) + var1, var2 = self.edge_var[k] + if var1 in aux_xs and var2 in aux_xs: + aux_edge = aux_edges.get(k, None) + vars[k] = self.embed_funcs[k.name]( + aux_xs[var1], aux_xs[var2], aux_edge + ) + return vars + + def predict_from_context( + self, + context_vars, + aux_xs, + frame_indices, + var_masks, + cross_masks, + enc_edges, + GT_lane_mode=None, + GT_homotopy=None, + prev_lm_pred_dict=defaultdict(lambda: None), + prev_homo_pred_dict=defaultdict(lambda: None), + num_samples=None, + GT_batch_mask=None, + ): + # GT_batch_mask: identify which indices (among batch dimension) are given the GT modes + + device = context_vars[TFvars.Agent_hist].device + if num_samples is None: + num_samples = self.num_joint_samples + # predict lane mode + agent_feat = context_vars[TFvars.Agent_hist] + lane_feat = context_vars[TFvars.Lane] + a2l_edge = context_vars[GNNedges.Agenthist2Lane] + a2a_edge = context_vars[GNNedges.Agenthist2Agenthist] + B, N, Th = agent_feat.shape[:3] + if GT_batch_mask is None and GT_lane_mode is not None: + GT_batch_mask = torch.ones(B, device=device, dtype=bool) + M = lane_feat.shape[1] + if self.use_hist_mode: + prev_lm_pred = [ + torch.zeros(B, N, M, dtype=torch.int, device=device) for t in range(Th) + ] + prev_homo_pred = [ + torch.zeros(B, N, N, dtype=torch.int, device=device) for t in range(Th) + ] + + for t, v in prev_lm_pred_dict.items(): + prev_lm_pred[t] = v + for t, v in prev_homo_pred_dict.items(): + if v is not None: + prev_homo_pred[t] = v + prev_lm_embed = self.lm_embed(torch.stack(prev_lm_pred, 2)) + prev_homo_embed = self.homotopy_embed(torch.stack(prev_homo_pred, 3)) + a2l_edge = torch.cat([a2l_edge, prev_lm_embed], -1) + a2a_edge = torch.cat([a2a_edge, prev_homo_embed], -1) + if self.pooling_T == "max": + a2l_edge_pool = a2l_edge.max(2)[0] + a2a_edge_pool = a2a_edge.max(3)[0] + agent_feat_pool = agent_feat.max(-2)[0] + + elif self.pooling_T == "attn": + query = a2l_edge[:, :, -1].reshape(B * N * M, 1, -1) + frame_indices2 = ( + frame_indices[TFvars.Agent_hist] + .reshape(B * N, -1) + .repeat_interleave(M, 0) + ) + frame_indices1 = frame_indices2[:, -1:] + hist_mask = var_masks[TFvars.Agent_hist] # B,N,T + mask = ( + (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)) + .repeat_interleave(M, 1) + .reshape(B * N * M, 1, Th) + ) + a2l_edge_pool = self.lane_mode_attn_T( + query, + a2l_edge.permute(0, 1, 3, 2, 4).reshape(B * N * M, Th, -1), + mask, + None, + None, + frame_indices1=frame_indices1, + frame_indices2=frame_indices2, + ).reshape(B, N, M, -1) + + query = a2a_edge[:, :, :, -1].reshape(B * N * N, 1, -1) + frame_indices2 = ( + frame_indices[TFvars.Agent_hist] + .reshape(B * N, -1) + .repeat_interleave(N, 0) + ) + frame_indices1 = frame_indices2[:, -1:] + mask = ( + (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)) + .repeat_interleave(N, 1) + .reshape(B * N * N, 1, Th) + ) + a2a_edge_pool = self.homotopy_attn_T( + query, + a2a_edge.reshape(B * N * N, Th, -1), + mask, + None, + None, + frame_indices1, + frame_indices2, + ).reshape(B, N, N, -1) + + query = agent_feat[:, :, -1].reshape(B * N, 1, -1) + frame_indices2 = frame_indices[TFvars.Agent_hist].reshape(B * N, -1) + frame_indices1 = frame_indices2[:, -1:] + mask = (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)).reshape( + B * N, 1, Th + ) + agent_feat_pool = self.agent_hist_attn_T( + query, + agent_feat.reshape(B * N, Th, -1), + mask, + None, + None, + frame_indices1, + frame_indices2, + ).view(B, N, -1) + + lane_mode_pred = self.lane_mode_net(a2l_edge_pool) + + if not self.classify_a2l_4all_lanes: # if self.per_mode + lane_mode_pred = lane_mode_pred.swapaxes( + -1, -2 + ) # [..., M, nbr_modes] -> [..., nbr_modes, M] + if self.null_lane_mode: + # add a null lane mode + lane_mode_pred = torch.cat( + [lane_mode_pred, torch.zeros_like(lane_mode_pred[..., 0:1])], -1 + ) + + lane_mode_prob = F.softmax(lane_mode_pred, dim=-1) + + homotopy_pred = self.homotopy_net(a2a_edge_pool) + homotopy_asym = homotopy_pred.transpose(1, 2) + homotopy_pred = (homotopy_pred + homotopy_asym) / 2 + + homotopy_prob = F.softmax(homotopy_pred, dim=-1) + + ########### Compute joint mode probability ############ + D_lane = lane_mode_pred.shape[ + -1 + ] # Either M (M+1) or len(self.fut_lane_relation) + D_homo = len(HomotopyType) + D = min([max([D_lane, D_homo]), self.max_joint_cardinality]) + M_lm = lane_mode_pred.shape[ + -2 + ] # number of possible modes for each lane,agent pair + + # factor masks for joint mode prediction + agent_mask = var_masks[TFvars.Agent_hist].any(-1).float() + lane_mask = var_masks[TFvars.Lane] + lm_factor_mask = agent_mask.unsqueeze(2) * lane_mask.unsqueeze(1) # (B,N,M) + if self.null_lane_mode and not self.classify_a2l_4all_lanes: + lane_mask_ext = torch.cat( + [lane_mask, torch.ones(*lane_mask.shape[:-1], 1, device=device)], -1 + ) + lm_factor_mask = torch.cat([lm_factor_mask, agent_mask.unsqueeze(-1)], -1) + else: + lane_mask_ext = lane_mask + + homo_factor_mask = agent_mask.unsqueeze(2) * agent_mask.unsqueeze(1) # (B,N,N) + if not self.classify_a2l_4all_lanes: + mode_mask = torch.ones( + [agent_mask.shape[0], M_lm], device=agent_mask.device + ) # All modes active + lm_factor_mask_per_mode = agent_mask.unsqueeze(2) * mode_mask.unsqueeze(1) + combined_factor_mask = torch.cat( + [ + lm_factor_mask_per_mode.view(B, N * M_lm), + homo_factor_mask.view(B, N * N), + ], + -1, + ) + else: + combined_factor_mask = torch.cat( + [lm_factor_mask.view(B, N * M_lm), homo_factor_mask.view(B, N * N)], -1 + ) + + if D_homo < D: + homotopy_pred_padded = torch.cat( + [ + homotopy_pred, + torch.ones(*homotopy_pred.shape[:-1], D - D_homo, device=device) + * -torch.inf, + ], + -1, + ) + elif D_homo > D: + raise NotImplementedError("We can only consider all homotopies") + else: + homotopy_pred_padded = homotopy_pred + + if ( + self.LR_sample_hack and GT_lane_mode is not None + ): # hack to favor sampling left and right lane of each agent + if ( + self.hist_lane_relation == LaneUtils.LaneRelation + and self.fut_lane_relation == LaneUtils.SimpleLaneRelation + ): + lane_mode_pred_m = lane_mode_pred.clone().detach() + try: + hist_lane_flag = enc_edges[ + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ][..., -len(LaneUtils.LaneRelation) :] + hist_lane_flag = hist_lane_flag.reshape(B, Th, N, M, -1)[ + :, -1 + ].type(torch.bool) + # pick out left and right lane and modify the probability + lm_logpi_max = lane_mode_pred_m.max(-1)[0] + left_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.LEFTOF + ].unsqueeze(2) + right_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.RIGHTOF + ].unsqueeze(2) + on_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.ON + ].unsqueeze(2) + lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~left_lane_mask) + + lm_logpi_max.unsqueeze(-1) * left_lane_mask + ) + lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~right_lane_mask) + + lm_logpi_max.unsqueeze(-1) * right_lane_mask + ) + lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~on_lane_mask) + + lm_logpi_max.unsqueeze(-1) * on_lane_mask + ) + except: + pass + + else: + lane_mode_pred_m = None + else: + lane_mode_pred_m = None + + if D_lane < D: + lane_mode_pred_padded = torch.cat( + [ + lane_mode_pred, + torch.ones(*lane_mode_pred.shape[:-1], D - D_lane, device=device) + * -torch.inf, + ], + -1, + ) + indices = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + elif D_lane > D: + # For each agent & mode, find the D likeliest lanes + assert ( + not self.classify_a2l_4all_lanes + ), "Currently only implemented for per_mode setting" + + lane_mode_pred_masked = lane_mode_pred.masked_fill( + torch.logical_not(lane_mask_ext[:, None, None, :]), -torch.inf + ) + + sorted, indices = lane_mode_pred_masked.topk( + D, dim=-1 + ) # Only consider most likely lanes! + lane_mode_pred_padded = torch.gather( + lane_mode_pred, -1, indices + ) # (...,M) -> (...,D) + else: + lane_mode_pred_padded = lane_mode_pred + indices = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + + combined_logpi = torch.cat( + [ + lane_mode_pred_padded.view(B, N * M_lm, -1), + homotopy_pred_padded.view(B, N * N, -1), + ], + 1, + ) + # for all invalid factors, turn the logpi to [0,-inf,-inf,-inf...] + combined_logpi[..., 1:].masked_fill_( + ~(combined_factor_mask.bool().unsqueeze(-1)), -torch.inf + ) + + if lane_mode_pred_m is not None: + if D_lane < D: + lane_mode_pred_padded_m = torch.cat( + [ + lane_mode_pred_m, + torch.ones( + *lane_mode_pred_m.shape[:-1], D - D_lane, device=device + ) + * -torch.inf, + ], + -1, + ) + indices_m = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + elif D_lane > D: + # For each agent & mode, find the D likeliest lanes + assert ( + not self.classify_a2l_4all_lanes + ), "Currently only implemented for per_mode setting" + lane_mode_pred_masked_m = lane_mode_pred_m.masked_fill( + torch.logical_not(lane_mask_ext[:, None, None, :]), -torch.inf + ) + sorted, indices_m = lane_mode_pred_masked_m.topk( + D, dim=-1 + ) # Only consider most likely lanes! + lane_mode_pred_padded_m = torch.gather( + lane_mode_pred_m, -1, indices_m + ) # (...,M) -> (...,D) + + else: + lane_mode_pred_padded_m = lane_mode_pred_m + indices_m = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + + combined_logpi_m = torch.cat( + [ + lane_mode_pred_padded_m.view(B, N * M_lm, -1), + homotopy_pred_padded.clone().detach().view(B, N * N, -1), + ], + 1, + ) + # for all invalid factors, turn the logpi to [0,-inf,-inf,-inf...] + combined_logpi_m[..., 1:].masked_fill_( + ~(combined_factor_mask.bool().unsqueeze(-1)), -torch.inf + ) + else: + combined_logpi_m = None + + # relevance score + # lane mode relevance score + # first, factors with one dominant mode is less relevant (or one dominant lane) + lm_dominance_score = -( + lane_mode_prob.max(dim=-1)[0] - 1 / lane_mode_prob.shape[-1] + ) # B,N,M_lm + homo_dominance_score = -( + homotopy_prob.max(dim=-1)[0] - 1 / homotopy_prob.shape[-1] + ) # B,N,N + # second, lanes that are far away from the agents are less important + a2ledge_raw = ModelUtils.agent2lane_edge_proj( + aux_xs[TFvars.Agent_hist][:, :, -1], aux_xs[TFvars.Lane] + ) + + if not self.classify_a2l_4all_lanes: + a2l_dis_score = 0 # For per mode, we don't consider lanes specifically so no a2l distance score + + else: + a2l_dis = a2ledge_raw[..., :2].norm(dim=-1) # B,N,M + a2l_dis_score = 1 / (a2l_dis + 1) + # lanes that are closer to ego is more important + + # third, agents far from ego is less relevant + a2eedge_raw = ModelUtils.agent2agent_edge( + aux_xs[TFvars.Agent_hist][:, 0:1, -1], aux_xs[TFvars.Agent_hist][:, :, -1] + ).squeeze(1) + a2e_dis = a2eedge_raw[..., :2].norm(dim=-1) + a2e_dis.masked_fill_(torch.logical_not(agent_mask.bool()), torch.inf) + a2e_dis_score = 1 / (a2e_dis.clip(min=0.5)) + a2e_dis_score_homo = a2e_dis_score.unsqueeze(1) * a2e_dis_score.unsqueeze(2) + a2e_dis_score_homo.masked_fill_( + torch.eye(N, device=a2e_dis_score.device, dtype=torch.bool).unsqueeze(0), 0 + ) + + lm_factor_score = ( + lm_dominance_score * 0.1 + + a2l_dis_score + + a2e_dis_score.unsqueeze(-1).repeat_interleave(M_lm, -1) * 2 + ) + homo_factor_score = homo_dominance_score + a2e_dis_score_homo + + # mask out half of homo_factor due to symmetry + sym_mask = torch.tril(torch.ones(B, N, N, dtype=torch.bool, device=device)) + homo_factor_score.masked_fill_(sym_mask, homo_factor_score.min().detach() - 5) + combined_factor_mask[:, -N * N :].masked_fill_(sym_mask.view(B, -1), 0) + if not self.classify_a2l_4all_lanes: + # Set the relevance score for ego agent (0) and the on lane index such that this is always a factor + max_score = torch.max(lm_factor_score.max(), homo_factor_score.max()) + on_lane_idx = ( + self.fut_lane_relation.ON - 1 + if self.fut_lane_relation.ON > self.fut_lane_relation.NOTON + else self.fut_lane_relation.ON + ) # We removed lane_pred + lm_factor_score[:, 0, on_lane_idx] = ( + 1.1 * max_score + ) # Temp, maybe make customizable + + combined_factor_score = torch.cat( + [lm_factor_score.view(B, N * M_lm), homo_factor_score.view(B, N * N)], -1 + ) + + num_factor = min(self.num_joint_factors, N * (M_lm + N)) + # modify the loglikelihood so that more important factors get more samples + temperature = 2 + modfied_combined_logpi = combined_logpi / torch.exp( + temperature + * (combined_factor_score - combined_factor_score.min()) + / (combined_factor_score.max() - combined_factor_score.min()) + ).unsqueeze(-1) + # Do importance sampling (sampling only the factors) the remaining marginals are chosen to be their argmax value + joint_sample, factor_idx = categorical_psample_wor( + modfied_combined_logpi, + num_samples, + num_factor, + factor_mask=combined_factor_mask, + relevance_score=combined_factor_score, + ) + + if combined_logpi_m is not None: + modfied_combined_logpi_m = combined_logpi_m / torch.exp( + temperature + * (combined_factor_score - combined_factor_score.min()) + / (combined_factor_score.max() - combined_factor_score.min()) + ).unsqueeze(-1) + joint_sample_m, factor_idx_m = categorical_psample_wor( + modfied_combined_logpi_m, + self.decode_num_modes, + num_factor, + factor_mask=combined_factor_mask, + relevance_score=combined_factor_score, + ) + else: + joint_sample_m = None + + # Turn indices of factors of importance back to combined_logpi size, then sum over num factors to get actual one hot (assuming nonrepetitive factors) + + ( + lm_sample, + homo_sample, + factor_mask, + ) = self.restore_lm_homotopy_from_joint_sample( + joint_sample, + indices, + factor_idx, + lm_factor_mask, + N, + lane_mask_ext.shape[-1], + M_lm, + ) + if joint_sample_m is not None: + ( + lm_sample_m, + homo_sample_m, + factor_mask_m, + ) = self.restore_lm_homotopy_from_joint_sample( + joint_sample_m, + indices_m, + factor_idx_m, + lm_factor_mask, + N, + lane_mask_ext.shape[-1], + M_lm, + ) + lm_sample = torch.cat([lm_sample, lm_sample_m], 1) + homo_sample = torch.cat([homo_sample, homo_sample_m], 1) + num_samples += lm_sample_m.size(1) + factor_mask = factor_mask | factor_mask_m + + if self.null_lane_mode: + # remove the factor_mask for the null lane + factor_mask = factor_mask.view(B, N, -1) + factor_mask = torch.cat( + [factor_mask[..., :M], factor_mask[..., M + 1 :]], -1 + ).reshape(B, -1) + + if ( + GT_lane_mode is not None and GT_homotopy is not None + ): # If we have GT, add it as sample for training + # for all the irrelevant entries, make them exactly the same as GT + if self.classify_a2l_4all_lanes: + lm_sample = lm_sample * lm_factor_mask.bool().unsqueeze(1) + ( + GT_lane_mode * torch.logical_not(lm_factor_mask.bool()) + ).unsqueeze(1) + + homo_sample = homo_sample * homo_factor_mask.bool().unsqueeze(1) + ( + GT_homotopy * torch.logical_not(homo_factor_mask.bool()) + ).unsqueeze(1) + # mask modes that are the same as GT + + lm_GT_flag = (lm_sample == GT_lane_mode.unsqueeze(1)).all(dim=3) + lm_GT_flag = lm_GT_flag.masked_fill( + torch.logical_not(agent_mask[:, None]), 1 + ).all(dim=2) + homo_GT_flag = homo_sample == GT_homotopy.unsqueeze(1) + homo_GT_flag = ( + homo_GT_flag.masked_fill( + torch.logical_not(homo_factor_mask)[:, None], 1 + ) + .all(-1) + .all(-1) + ) + GT_flag = (lm_GT_flag & homo_GT_flag) * GT_batch_mask[:, None] + # make sure that at least one entry of GT_flag is true for all scene, we can then remove 1 sample from every scene + num_samples = num_samples - GT_flag.sum(-1).max().item() + + lm_sample = torch.stack( + [ + lm_sample[i][torch.where(~GT_flag[i])[0]][:num_samples] + for i in range(B) + ], + 0, + ) + homo_sample = torch.stack( + [ + homo_sample[i][torch.where(~GT_flag[i])[0]][:num_samples] + for i in range(B) + ], + 0, + ) + if GT_batch_mask.all(): + lm_sample = torch.cat([GT_lane_mode.unsqueeze(1), lm_sample], 1) + homo_sample = torch.cat([GT_homotopy.unsqueeze(1), homo_sample], 1) + else: + lm_sample_GT = torch.cat([GT_lane_mode.unsqueeze(1), lm_sample], 1) + homo_sample_GT = torch.cat([GT_homotopy.unsqueeze(1), homo_sample], 1) + # for instances where GT is not provided, simply pad the last index + lm_sample_nonGT = torch.cat([lm_sample, lm_sample[:, :1]], 1) + homo_sample_nonGT = torch.cat([homo_sample, homo_sample[:, :1]], 1) + lm_sample = lm_sample_GT * GT_batch_mask[ + :, None, None, None + ] + lm_sample_nonGT * torch.logical_not( + GT_batch_mask[:, None, None, None] + ) + homo_sample = homo_sample_GT * GT_batch_mask[ + :, None, None, None + ] + homo_sample_nonGT * torch.logical_not( + GT_batch_mask[:, None, None, None] + ) + num_samples += 1 + + lm_embedding = self.lm_embed(lm_sample[..., :M]) + + homo_embedding = self.homotopy_embed(homo_sample) + a2l_edge_with_mode = torch.cat( + [ + a2l_edge_pool.repeat_interleave(num_samples, 0), + lm_embedding.view(B * num_samples, N, M, -1), + ], + -1, + ) + a2a_edge_with_mode = torch.cat( + [ + a2a_edge_pool.repeat_interleave(num_samples, 0), + homo_embedding.view(B * num_samples, N, N, -1), + ], + -1, + ) + + JM_GNN_vars = { + TFvars.Agent_hist: agent_feat_pool.repeat_interleave( + num_samples, 0 + ).unsqueeze(-2), + TFvars.Lane: context_vars[TFvars.Lane].repeat_interleave(num_samples, 0), + GNNedges.Agenthist2Lane: a2l_edge_with_mode.unsqueeze(2), + GNNedges.Agenthist2Agenthist: a2a_edge_with_mode.unsqueeze(3), + } + JM_GNN_var_masks = { + TFvars.Agent_hist: var_masks[TFvars.Agent_hist] + .any(2, keepdim=True) + .repeat_interleave(num_samples, 0), + TFvars.Lane: var_masks[TFvars.Lane].repeat_interleave(num_samples, 0), + } + for i in range(len(self.JM_GNN)): + JM_GNN_vars = self.JM_GNN[i]( + JM_GNN_vars, JM_GNN_var_masks, cross_masks=dict() + ) + + lm_factor = self.JM_lane_mode_factor(JM_GNN_vars[GNNedges.Agenthist2Lane]).view( + B, num_samples, N, M + ) + + homotopy_factor = self.JM_homotopy_factor( + JM_GNN_vars[GNNedges.Agenthist2Agenthist] + ).view(B, num_samples, N, N) + + total_factor = torch.cat([lm_factor, homotopy_factor], -1).reshape( + B, num_samples, N * (M + N) + ) + total_factor_mask = ( + torch.cat([lm_factor_mask[..., :M], homo_factor_mask], -1).reshape(B, -1) + * factor_mask + ) + joint_logpi = ( + total_factor * (total_factor_mask * factor_mask).unsqueeze(1) + ).sum(-1) + + if joint_sample_m is not None: + num_sample_m = joint_sample_m.size(1) + ml_idx = joint_logpi.argmax(dim=1, keepdim=True) + # decode the GT, the most likely mode other than GT, and the modes sampled with modified probability + if GT_lane_mode is not None: + dec_idx = torch.cat( + [ + torch.zeros(B, 1, device=device), + ml_idx, + torch.arange( + num_samples - num_sample_m, num_samples, device=device + )[None].repeat_interleave(B, 0), + ], + dim=1, + ).type(torch.int64) + else: + dec_idx = torch.cat( + [ + ml_idx, + torch.arange( + num_samples - num_sample_m, num_samples, device=device + )[None].repeat_interleave(B, 0), + ], + dim=1, + ).type(torch.int64) + else: + dec_idx = joint_logpi.topk( + min(self.decode_num_modes, joint_logpi.shape[1]), dim=1 + ).indices + if GT_lane_mode is not None: + dec_idx = torch.cat( + [torch.zeros(B, 1, device=device), dec_idx], 1 + ).type(torch.int64) + if self.null_lane_mode: + dec_cond_lm = torch.gather( + lm_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, M + 1) + ) + else: + dec_cond_lm = torch.gather( + lm_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, M) + ) + dec_cond_homotopy = torch.gather( + homo_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, N) + ) + joint_mode_prob = torch.softmax(joint_logpi, -1) + dec_cond_prob = torch.gather(joint_mode_prob, 1, dec_idx) + # homotopy result test + # dec_cond_homotopy[0,1,0,2]=2 + # dec_cond_homotopy[0,1,2,0]=2 + return dict( + lane_mode_pred=lane_mode_pred, + homotopy_pred=homotopy_pred, + joint_pred=joint_logpi, + lane_mode_sample=lm_sample, + homotopy_sample=homo_sample, + dec_cond_lm=dec_cond_lm, + dec_cond_homotopy=dec_cond_homotopy, + dec_cond_prob=dec_cond_prob, + ) + + def decode_trajectory_AR( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ): + ARS = self.AR_step_size + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + M = aux_xs[TFvars.Lane].shape[1] + device = enc_vars[TFvars.Agent_hist].device + raw_agent_hist = raw_vars[TFvars.Agent_hist] # [B,N,Th,F] + raw_agent_fut = list( + raw_agent_hist[:, :, -1:] + .repeat_interleave(1 + self.dec_T, 2) + .split(1, dim=2) + ) + state_out = {k: list() for k in self.dyn} + input_out = {k: list() for k in self.dyn} + Tf = self.dec_T + curr_yaw = ratan2( + aux_xs[TFvars.Agent_future][:, :, 0, 3:4], + aux_xs[TFvars.Agent_future][:, :, 0, 4:5], + ) + curr_yaw = curr_yaw * var_masks[TFvars.Agent_future][:, :, 0:1] + ar_mask = torch.zeros([1, 1, Tf + 1], dtype=bool, device=device) + xt = x0 + dec_kwargs_ar = dict( + aux_xs=aux_xs, + cross_masks=cross_masks, + frame_indices=frame_indices, + ) + lane_xysc = aux_xs[TFvars.Lane].reshape(B, M, -1, 4) + future_xyvsc = list(torch.split(aux_xs[TFvars.Agent_future][..., :5], 1, dim=2)) + for i in range(Tf // ARS): + dec_raw_vars = { + TFvars.Agent_future: torch.cat(raw_agent_fut, 2), + TFvars.Lane: raw_vars[TFvars.Lane], + } + + ar_mask[:, :, : i * ARS + 1] = 1 + x_mask = var_masks[TFvars.Agent_future] * ar_mask + + cross_masks = dict() + Agent_fut_cross_mask = var_masks[TFvars.Agent_future].unsqueeze( + -1 + ) * var_masks[TFvars.Agent_future].unsqueeze(-2) + Agent_fut_cross_mask = torch.tril(Agent_fut_cross_mask) * ar_mask.unsqueeze( + -2 + ) + + cross_masks[ + TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.T + ] = Agent_fut_cross_mask + + dec_kwargs_ar["var_masks"] = var_masks + dec_kwargs_ar["cross_masks"] = cross_masks + future_xyvsc_tensor = torch.cat(future_xyvsc, dim=2) + future_aux = torch.cat( + [future_xyvsc_tensor, aux_xs[TFvars.Agent_future][..., 5:]], -1 + ) + aux_xs[TFvars.Agent_future] = future_aux + future_xysc = future_xyvsc_tensor[..., [0, 1, 3, 4]] + lane_margins = ( + self.fut_lane_relation.get_all_margins( + TensorUtils.join_dimensions(future_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + ) + .reshape(B, N, M, 1 + Tf, -1) + .clip(-20, 20) + ) + lane_margins.masked_fill_( + ~( + x_mask[:, :, None, :, None] + * var_masks[TFvars.Lane][:, None, :, None, None] + ).bool(), + 0, + ) + # mask out all unselected lanes + lane_margins.masked_fill_( + lane_mode[..., self.fut_lane_relation.NOTON, None, None].bool(), 0 + ) + agent_edge = future_xysc[..., :2].unsqueeze(2) - future_xysc[ + ..., :2 + ].unsqueeze(1) + rel_angle = ratan2(agent_edge[..., [1]], agent_edge[..., [0]]) + rel_angle_offset = torch.cat( + [rel_angle[..., :1, :], rel_angle[..., :-1, :]], -2 + ) + + angle_diff = (rel_angle - rel_angle_offset).cumsum(-2) + homotopy_margins = torch.cat( + [ + HOMOTOPY_THRESHOLD - angle_diff.abs(), + -angle_diff - HOMOTOPY_THRESHOLD, + angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + homotopy_margins.masked_fill_( + torch.logical_not(x_mask.unsqueeze(1) * x_mask.unsqueeze(2)).unsqueeze( + -1 + ), + 0, + ) + aux_dec_edges = { + GNNedges.Agentfuture2Agentfuture: torch.cat( + [ + homotopy.unsqueeze(3).repeat_interleave(1 + self.dec_T, 3), + homotopy_margins, + ], + -1, + ), + GNNedges.Agentfuture2Lane: torch.cat( + [ + lane_mode.unsqueeze(2).repeat_interleave(1 + self.dec_T, 2), + lane_margins.transpose(2, 3), + ], + -1, + ), + } + dec_vars = self.embed_raw( + dec_raw_vars, + {TFvars.Agent_future: future_aux, TFvars.Lane: aux_xs[TFvars.Lane]}, + aux_dec_edges, + ) + vars = {**enc_vars, **dec_vars} + # update edges due to change of aux_xs + dec_agent_aux = future_aux.permute(0, 2, 1, 3).reshape(B * (Tf + 1), N, -1) + + a2a_edge = torch.cat( + [ + self.edge_func["a2a"](dec_agent_aux, dec_agent_aux), + homotopy.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + homotopy_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + a2l_edge = torch.cat( + [ + self.edge_func["a2l"]( + dec_agent_aux, aux_xs[TFvars.Lane].repeat_interleave(Tf + 1, 0) + ), + lane_mode.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + lane_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + edges = { + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A): a2a_edge, + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): [ + a2l_edge + ], + } + dec_kwargs_ar["edges"] = edges + vars = self.decoder(vars, dec_kwargs_ar) + out = dict() + + if self.AR_update_mode == "step": + if self.arch == "lstm": + raise ValueError("LSTM arch is not working!") + + elif self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B * N, Tf + 1, -1)[ + :, i * ARS : (i + 1) * ARS + ] + ).view(B, N, ARS, -1) + else: + raise NotImplementedError + + elif self.AR_update_mode == "all": + if self.arch == "lstm": + raise ValueError("LSTM arch is not working!") + elif self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B * N, Tf + 1, -1)[ + :, : (i + 1) * ARS + ] + ).view(B, N, ARS * (i + 1), -1) + else: + raise NotImplementedError + + agent_xyvsc_by_type = dict() + agent_acce = dict() + agent_r = dict() + for k, dyn in self.dyn.items(): + if dyn is None: + # output assumed to take form [x,y,v,h] + if center_from_agents is None: + agent_xyvsc_by_type[k] = torch.cat( + [ + out[k][..., :3], + torch.sin(out[k][..., 3:]), + torch.cos(out[k][..., 3:]), + ], + -1, + ) + + else: + # transform to global coordinate + xy_local = out[k][..., :2] + agent_v = out[k][..., 2:3] + h_local = out[k][..., 3:4] + xy_global = GeoUtils.batch_nd_transform_points( + xy_local, center_from_agents.unsqueeze(2) + ) + h_global = h_local + curr_yaw.unsqueeze(2) + agent_xyvsc_by_type[k] = torch.cat( + [ + xy_global, + agent_v, + torch.sin(h_global), + torch.cos(h_global), + ], + -1, + ) + + agent_v = out[k][..., 2:3] + if self.AR_update_mode == "step": + agent_v_pre = torch.cat( + future_xyvsc[i * ARS : (i + 1) * ARS], 2 + )[..., 2:3] + # here we assume input is always 2 dimensional + input_out[k].append(torch.zeros_like(out[k][..., :2])) + elif self.AR_update_mode == "all": + agent_v_pre = torch.cat(future_xyvsc[: (i + 1) * ARS], 2)[ + ..., 2:3 + ] + + agent_acce[k] = (agent_v - agent_v_pre) / self.dt + agent_r[k] = torch.zeros_like(agent_v) + + else: + if isinstance(dyn, Unicycle): + xseq = dyn.forward_dynamics( + xt[k], out[k], mode="chain", bound=False + ) + + if self.AR_update_mode == "step": + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + state_out[k].append(xseq) + xt[k] = xseq[:, :, -1] + elif self.AR_update_mode == "all": + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + + elif isinstance(dyn, Unicycle_xyvsc): + xseq = dyn.forward_dynamics(xt[k], out[k], bound=False) + + agent_xyvsc_by_type[k] = xseq + if self.AR_update_mode == "step": + xt[k] = xseq[:, :, -1] + state_out[k].append(xseq) + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + else: + raise NotImplementedError + input_out[k].append(out[k]) + agent_xyvsc_combined = sum( + [ + agent_mask[k][..., None, None] * agent_xyvsc_by_type[k] + for k in agent_xyvsc_by_type + ] + ) + + acce_combined = sum( + [agent_mask[k][..., None, None] * agent_acce[k] for k in agent_acce] + ) + r_combined = sum( + [agent_mask[k][..., None, None] * agent_r[k] for k in agent_r] + ) + agent_v = agent_xyvsc_combined[..., 2:3] + # update the agent future aux + + # update the agent future raw [v,acce,r,static_features] + + agent_vu = torch.cat([agent_v, acce_combined, r_combined], -1) + if self.AR_update_mode == "step": + future_xyvsc[i * ARS + 1 : (i + 1) * ARS + 1] = list( + torch.split(agent_xyvsc_combined, 1, dim=2) + ) + for j in range(ARS): + raw_agent_fut[i * ARS + 1 + j] = torch.cat( + [ + agent_vu[:, :, j : j + 1], + raw_agent_fut[i * ARS + 1 + j][..., 3:], + ], + -1, + ) + elif self.AR_update_mode == "all": + future_xyvsc[: (i + 1) * ARS + 1] = list( + torch.split(agent_xyvsc_combined, 1, dim=2) + ) + for j in range(ARS * (i + 1)): + raw_agent_fut[j] = torch.cat( + [agent_vu[:, :, j : j + 1], raw_agent_fut[j][..., 3:]], -1 + ) + + if self.AR_update_mode == "step": + future_xyvsc = torch.cat(future_xyvsc, dim=2) + input_out = { + k: torch.cat(input_type, dim=2) if len(input_type) > 0 else None + for k, input_type in input_out.items() + } + state_out = { + k: torch.cat(state_type, dim=2) + if len(state_type) > 0 + else future_xyvsc[..., 1:, :] + for k, state_type in state_out.items() + } + elif self.AR_update_mode == "all": + future_xyvsc = agent_xyvsc_combined + for k, v in out.items(): + if v is not None: + input_out[k] = v + else: + input_out[k] = torch.zeros_like(out[k][..., :2]) + + trajectories = torch.cat( + [ + future_xyvsc[:, :, 1:, :2], + ratan2(future_xyvsc[:, :, 1:, 3:4], future_xyvsc[:, :, 1:, 4:5]), + ], + -1, + ) + + input_violation = dict() + jerk = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + xl = torch.cat([x0[k].unsqueeze(-2), state_out[k][..., :-1, :]], -2) + input_violation[k] = ( + dyn.get_input_violation(xl, input_out[k]) + * agent_mask[k][..., None, None] + ) + inputs_extended = torch.cat([u0[k].unsqueeze(-2), input_out[k]], -2) + jerk[k] = ( + (inputs_extended[..., 1:, :] - inputs_extended[..., :-1, :]) + / self.dt + ) * agent_mask[k][..., None, None] + return trajectories, input_out, state_out, input_violation, jerk + + def decode_trajectory_OS( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ): + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + M = aux_xs[TFvars.Lane].shape[1] + device = enc_vars[TFvars.Agent_hist].device + raw_agent_hist = raw_vars[TFvars.Agent_hist] # [B,N,Th,F] + raw_agent_fut = raw_agent_hist[:, :, -1:].repeat_interleave(1 + self.dec_T, 2) + Tf = self.dec_T + curr_yaw = ratan2( + aux_xs[TFvars.Agent_future][:, :, 0, 3:4], + aux_xs[TFvars.Agent_future][:, :, 0, 4:5], + ) + curr_yaw = curr_yaw * var_masks[TFvars.Agent_future][:, :, 0:1] + dec_kwargs = dict( + aux_xs=aux_xs, + cross_masks=cross_masks, + frame_indices=frame_indices, + ) + lane_xysc = aux_xs[TFvars.Lane].reshape(B, M, -1, 4) + future_xyvsc = aux_xs[TFvars.Agent_future][..., :5] + input_out = dict() + state_out = dict() + for i in range(self.dec_rounds): + dec_raw_vars = { + TFvars.Agent_future: raw_agent_fut, + TFvars.Lane: raw_vars[TFvars.Lane], + } + + x_mask = var_masks[TFvars.Agent_future] + + cross_masks = dict() + Agent_fut_cross_mask = var_masks[TFvars.Agent_future].unsqueeze( + -1 + ) * var_masks[TFvars.Agent_future].unsqueeze(-2) + + cross_masks[ + TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.T + ] = torch.tril(Agent_fut_cross_mask) + + dec_kwargs["var_masks"] = var_masks + dec_kwargs["cross_masks"] = cross_masks + future_aux = torch.cat( + [future_xyvsc, aux_xs[TFvars.Agent_future][..., 5:]], -1 + ) + aux_xs[TFvars.Agent_future] = future_aux + future_xysc = future_xyvsc[..., [0, 1, 3, 4]] + lane_margins = ( + self.fut_lane_relation.get_all_margins( + TensorUtils.join_dimensions(future_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + ) + .reshape(B, N, M, 1 + Tf, -1) + .clip(-20, 20) + ) + lane_margins.masked_fill_( + ~( + x_mask[:, :, None, :, None] + * var_masks[TFvars.Lane][:, None, :, None, None] + ).bool(), + 0, + ) + # mask out all unselected lanes + lane_margins.masked_fill_( + lane_mode[..., self.fut_lane_relation.NOTON, None, None].bool(), 0 + ) + agent_edge = future_xysc[..., :2].unsqueeze(2) - future_xysc[ + ..., :2 + ].unsqueeze(1) + rel_angle = ratan2(agent_edge[..., [1]], agent_edge[..., [0]]) + rel_angle_offset = torch.cat( + [rel_angle[..., :1, :], rel_angle[..., :-1, :]], -2 + ) + + angle_diff = (rel_angle - rel_angle_offset).cumsum(-2) + homotopy_margins = torch.cat( + [ + HOMOTOPY_THRESHOLD - angle_diff.abs(), + -angle_diff - HOMOTOPY_THRESHOLD, + angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + homotopy_margins.masked_fill_( + torch.logical_not(x_mask.unsqueeze(1) * x_mask.unsqueeze(2)).unsqueeze( + -1 + ), + 0, + ) + aux_dec_edges = { + GNNedges.Agentfuture2Agentfuture: torch.cat( + [ + homotopy.unsqueeze(3).repeat_interleave(1 + self.dec_T, 3), + homotopy_margins, + ], + -1, + ), + GNNedges.Agentfuture2Lane: torch.cat( + [ + lane_mode.unsqueeze(2).repeat_interleave(1 + self.dec_T, 2), + lane_margins.transpose(2, 3), + ], + -1, + ), + } + dec_vars = self.embed_raw( + dec_raw_vars, + {TFvars.Agent_future: future_aux, TFvars.Lane: aux_xs[TFvars.Lane]}, + aux_dec_edges, + ) + vars = {**enc_vars, **dec_vars} + # update edges due to change of aux_xs + dec_agent_aux = future_aux.permute(0, 2, 1, 3).reshape(B * (Tf + 1), N, -1) + + a2a_edge = torch.cat( + [ + self.edge_func["a2a"](dec_agent_aux, dec_agent_aux), + homotopy.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + homotopy_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + a2l_edge = torch.cat( + [ + self.edge_func["a2l"]( + dec_agent_aux, aux_xs[TFvars.Lane].repeat_interleave(Tf + 1, 0) + ), + lane_mode.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + lane_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + edges = { + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A): a2a_edge, + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): [ + a2l_edge + ], + } + dec_kwargs["edges"] = edges + vars = self.decoder(vars, dec_kwargs) + out = dict() + + if self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B, N, Tf + 1, -1) + ) + else: + raise NotImplementedError + + agent_xyvsc_by_type = dict() + agent_acce = dict() + agent_r = dict() + for k, dyn in self.dyn.items(): + if dyn is None: + # output assumed to take form [x,y,v,h] + if center_from_agents is None: + agent_xyvsc_by_type[k] = torch.cat( + [ + out[k][..., :3], + torch.sin(out[k][..., 3:]), + torch.cos(out[k][..., 3:]), + ], + -1, + )[..., 1:, :] + + else: + # transform to global coordinate + xy_local = out[k][..., :2] + agent_v = out[k][..., 2:3] + h_local = out[k][..., 3:4] + xy_global = GeoUtils.batch_nd_transform_points( + xy_local, center_from_agents.unsqueeze(2) + ) + h_global = h_local + curr_yaw.unsqueeze(2) + agent_xyvsc_by_type[k] = torch.cat( + [ + xy_global, + agent_v, + torch.sin(h_global), + torch.cos(h_global), + ], + -1, + )[..., 1:, :] + + agent_v = out[k][..., 2:3] + + agent_acce[k] = ( + out[k][..., 1:, 2:3] - out[k][..., :Tf, 2:3] + ) / self.dt + agent_acce[k] = torch.cat( + [agent_acce[k], agent_acce[k][..., -1:, :]], -2 + ) + agent_r[k] = torch.zeros_like(agent_v) + + else: + if isinstance(dyn, Unicycle): + xseq = dyn.forward_dynamics( + x0[k], out[k][..., :Tf, :], mode="chain", bound=False + ) + + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + + elif isinstance(dyn, Unicycle_xyvsc): + xseq = dyn.forward_dynamics( + x0[k], out[k][..., :Tf, :], bound=False + ) + + agent_xyvsc_by_type[k] = xseq + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + else: + raise NotImplementedError + state_out[k] = xseq + future_xyvsc[..., 1:, :] = sum( + [ + agent_mask[k][..., None, None] * agent_xyvsc_by_type[k] + for k in agent_xyvsc_by_type + ] + ) + + acce_combined = sum( + [agent_mask[k][..., None, None] * agent_acce[k] for k in agent_acce] + ) + r_combined = sum( + [agent_mask[k][..., None, None] * agent_r[k] for k in agent_r] + ) + agent_v = future_xyvsc[..., 2:3] + # update the agent future aux + + # update the agent future raw [v,acce,r,static_features] + + agent_vu = torch.cat([agent_v, acce_combined, r_combined], -1) + raw_agent_fut = torch.cat([agent_vu, raw_agent_fut[..., 3:]], -1) + for k, v in out.items(): + if v is not None: + input_out[k] = v[..., :Tf, :] + else: + input_out[k] = torch.zeros_like(out[k][..., :Tf, :2]) + + trajectories = torch.cat( + [ + future_xyvsc[:, :, 1:, :2], + ratan2(future_xyvsc[:, :, 1:, 3:4], future_xyvsc[:, :, 1:, 4:5]), + ], + -1, + ) + + input_violation = dict() + jerk = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + xl = torch.cat([x0[k].unsqueeze(-2), state_out[k][..., :-1, :]], -2) + input_violation[k] = ( + dyn.get_input_violation(xl, input_out[k][..., :Tf, :]) + * agent_mask[k][..., None, None] + ) + inputs_extended = torch.cat([u0[k].unsqueeze(-2), input_out[k]], -2) + jerk[k] = ( + (inputs_extended[..., 1:, :] - inputs_extended[..., :-1, :]) + / self.dt + ) * agent_mask[k][..., None, None] + return trajectories, input_out, state_out, input_violation, jerk + + def decode_trajectory( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + lane_mode, + homotopy, + x0, + u0, + center_from_agents, + num_samples=None, + ): + device = enc_vars[TFvars.Agent_hist].device + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + assert AgentType.VEHICLE in self.dyn + agent_mask = dict() + for k, v in self.dyn.items(): + if k == AgentType.VEHICLE: + # this is the default mode + other_types = [k.value for k in self.dyn if k != AgentType.VEHICLE] + + agent_mask[k] = torch.logical_not( + agent_type[..., other_types].any(-1) + ) * agent_type.any(-1) + else: + agent_mask[k] = agent_type[..., k.value] + if lane_mode.ndim == 4: + # sample mode + sample_mode = True + num_samples = homotopy.shape[1] + lane_mode, homotopy = TensorUtils.join_dimensions( + (lane_mode, homotopy), 0, 2 + ) + B0 = B + B = B0 * num_samples + enc_vars = { + k: v.repeat_interleave(num_samples, 0) for k, v in enc_vars.items() + } + raw_vars = { + k: v.repeat_interleave(num_samples, 0) for k, v in raw_vars.items() + } + aux_xs = {k: v.repeat_interleave(num_samples, 0) for k, v in aux_xs.items()} + var_masks = { + k: v.repeat_interleave(num_samples, 0) for k, v in var_masks.items() + } + cross_masks = { + k: v.repeat_interleave(num_samples, 0) for k, v in cross_masks.items() + } + frame_indices = { + k: v.repeat_interleave(num_samples, 0) for k, v in frame_indices.items() + } + for k, v in x0.items(): + if v is not None: + x0[k] = v.repeat_interleave(num_samples, 0) + for k, v in u0.items(): + if v is not None: + u0[k] = v.repeat_interleave(num_samples, 0) + center_from_agents = center_from_agents.repeat_interleave(num_samples, 0) + for k, v in agent_mask.items(): + agent_mask[k] = v.repeat_interleave(num_samples, 0) + + else: + sample_mode = False + + lane_mode = F.one_hot(lane_mode, len(self.fut_lane_relation)).float() + if self.null_lane_mode: + lane_mode = lane_mode[..., :-1, :] + homotopy = F.one_hot(homotopy, len(HomotopyType)).float() + + Tf = self.dec_T + # create agent_future from agent_hist for auto-regressive decoding + + var_masks[TFvars.Agent_future] = ( + var_masks[TFvars.Agent_hist] + .any(2, keepdim=True) + .repeat_interleave(1 + Tf, 2) + ) + aux_xs[TFvars.Agent_future] = aux_xs[TFvars.Agent_hist][ + :, :, -1: + ].repeat_interleave(1 + Tf, 2) + + frame_indices[TFvars.Agent_future] = Th + torch.arange( + 1 + self.dec_T, device=device + )[None, None, :].repeat(B, N, 1) + + future_xyvsc = list(torch.split(aux_xs[TFvars.Agent_future][..., :5], 1, dim=2)) + + if self.AR_update_mode is not None: + ( + trajectories, + input_out, + state_out, + input_violation, + jerk, + ) = self.decode_trajectory_AR( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ) + else: + ( + trajectories, + input_out, + state_out, + input_violation, + jerk, + ) = self.decode_trajectory_OS( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ) + + if sample_mode: + trajectories = trajectories.view(B0, num_samples, N, Tf, -1) + state_out = { + k: v.view(B0, num_samples, N, Tf, -1) if v is not None else None + for k, v in state_out.items() + } + input_out = { + k: v.view(B0, num_samples, N, Tf, -1) if v is not None else None + for k, v in input_out.items() + } + agent_mask = {k: v.view(B0, num_samples, N) for k, v in agent_mask.items()} + input_violation = { + k: v.view(B0, num_samples, N, Tf, -1) + for k, v in input_violation.items() + } + jerk = {k: v.view(B0, num_samples, N, Tf, -1) for k, v in jerk.items()} + + return dict( + trajectories=trajectories, + states=state_out, + inputs=input_out, + input_violation=input_violation, + jerk=jerk, + type_mask=agent_mask, + future_xyvsc=future_xyvsc, + ) + + def restore_lm_homotopy_from_joint_sample( + self, joint_sample, indices, factor_idx, lm_factor_mask, N, M, M_lm + ): + device = joint_sample.device + B, num_samples = joint_sample.shape[:2] + factor_mask = F.one_hot(factor_idx, N * (N + M_lm)).sum( + -2 + ) # (B, |combined_log_pi|) + if not self.classify_a2l_4all_lanes: + factor_mask = torch.cat( + [ + torch.ones(B, N * M, device=device, dtype=torch.int), + factor_mask[:, N * M_lm :], + ], + -1, + ) + candidate_flag = torch.zeros(B, N, M, dtype=torch.bool, device=device) + for i in range(M_lm): + candidate_flag.scatter_(-1, indices[:, :, i], 1) # Back to lanes + lm_factor_mask = lm_factor_mask * candidate_flag + + if not self.classify_a2l_4all_lanes: + lm_sample_per_mode = joint_sample[..., : N * M_lm].view( + B, -1, N, M_lm + ) # First N*M_lm are mode contributions + lm_sample = torch.zeros( + [B, joint_sample.size(1), N, M], dtype=torch.long, device=device + ) + + i = 0 + lm_sample_per_mode = lm_sample_per_mode.clip(0, indices.shape[-1] - 1) + for mode in self.fut_lane_relation: + if mode == self.fut_lane_relation.NOTON: + continue + # restore from the important sampling of of lane segments + restored_lm_samples_i = torch.gather( + indices[:, :, i].unsqueeze(1).repeat_interleave(num_samples, 1), + -1, + lm_sample_per_mode[..., i : i + 1], + ).squeeze(-1) + # value of the lane mode + mode_enum = i + 1 if mode > self.fut_lane_relation.NOTON else i + lm_sample = lm_sample + (lm_sample == 0) * ( + torch.scatter( + lm_sample, -1, restored_lm_samples_i.unsqueeze(-1), mode_enum + ) + ) + i += 1 + else: + lm_sample = joint_sample[..., : N * M].view( + B, -1, N, M + ) # .clip(0,len(self.fut_lane_relation)-1) + lm_sample.masked_fill_(~lm_factor_mask.bool().unsqueeze(1), 0) + homo_sample = joint_sample[..., N * M_lm :].view(B, -1, N, N) + homo_sample = HomotopyType.enforce_symmetry(homo_sample) + return lm_sample, homo_sample, factor_mask + + def forward( + self, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + enc_edges=dict(), + dec_edges=dict(), + GT_lane_mode=None, + GT_homotopy=None, + center_from_agents=None, + num_samples=None, + ): + enc_raw_vars = { + k: v + for k, v in raw_vars.items() + if k in self.enc_vars or k in self.enc_edges + } + enc_vars = self.embed_raw(enc_raw_vars, aux_xs) + + x0 = dict() + u0 = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + # get x0 and u0 from aux_xs + if isinstance(dyn, Unicycle_xyvsc): + x0[k] = aux_xs[TFvars.Agent_hist][:, :, -1, :5] + xprev = aux_xs[TFvars.Agent_hist][:, :, -2, :5] + u0[k] = dyn.inverse_dyn(x0[k], xprev, self.dt) + elif isinstance(dyn, Unicycle): + xyvsc = aux_xs[TFvars.Agent_hist][:, :, -2:, :5] + h = ratan2(xyvsc[..., 3:4], xyvsc[..., 4:5]) + x0[k] = torch.cat([xyvsc[..., -1, :3], h[..., -1, :]], -1) + xprev = torch.cat([xyvsc[..., -2, :3], h[..., -2, :]], -1) + u0[k] = dyn.inverse_dyn(x0[k], xprev, self.dt) + else: + raise NotImplementedError + else: + x0[k] = None + u0[k] = None + + enc_kwargs = dict( + aux_xs=aux_xs, + var_masks=var_masks, + cross_masks=cross_masks, + frame_indices=frame_indices, + edges=enc_edges, + ) + enc_vars = self.encoder(enc_vars, enc_kwargs) + mode_pred = self.predict_from_context( + enc_vars, + aux_xs, + frame_indices, + var_masks, + cross_masks, + enc_edges, + GT_lane_mode, + GT_homotopy, + num_samples=num_samples, + ) + + # decode trajectory predictions + # prepare modes + + if GT_lane_mode is not None and GT_homotopy is not None: + # train mode + if self.decode_num_modes == 1: + lane_mode = GT_lane_mode + homotopy = GT_homotopy + mode_pred["dec_cond_lm"] = GT_lane_mode.unsqueeze(1) + mode_pred["dec_cond_homotopy"] = GT_homotopy.unsqueeze(1) + mode_pred["dec_cond_prob"] = torch.ones( + [GT_lane_mode.shape[0], 1], device=GT_lane_mode.device + ) + else: + lane_mode = mode_pred["dec_cond_lm"] + homotopy = mode_pred["dec_cond_homotopy"] + + else: + # infer mode + lane_mode = mode_pred["lane_mode_sample"] + homotopy = mode_pred["homotopy_sample"] + + dec_vars = self.decode_trajectory( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + lane_mode, + homotopy, + x0=x0, + u0=u0, + center_from_agents=center_from_agents, + ) + return dec_vars, mode_pred diff --git a/diffstack/models/RPE_simple.py b/diffstack/models/RPE_simple.py new file mode 100644 index 0000000..0bd885e --- /dev/null +++ b/diffstack/models/RPE_simple.py @@ -0,0 +1,441 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +from diffstack.configs.config import Dict +from diffstack.models.TypeTransformer import NewGELU, zero_module +from diffstack.utils.model_utils import * + +""" +modified from RPE.py with the following changes: +1. removed diffusion time embedding +2. removed the D dimension +3. reversed the order of T and C +4. changed attn_mask to be [B,T,T] instead of [B,T] +5. removed the residual connection, instead, output the residual itself +6. removed the batchnorm +""" +class sRPENet(nn.Module): + def __init__(self, n_embd, num_heads): + super().__init__() + self.embed_distances = nn.Linear(3, n_embd) + + self.gelu = NewGELU() + self.out = nn.Linear(n_embd, n_embd) + self.out.weight.data *= 0. + self.out.bias.data *= 0. + self.n_embd = n_embd + self.num_heads = num_heads + + def forward(self, relative_distances): + distance_embs = torch.stack( + [torch.log(1+(relative_distances).clamp(min=0)), + torch.log(1+(-relative_distances).clamp(min=0)), + (relative_distances == 0).float()], + dim=-1 + ) # BxTxTx3 + if self.embed_distances.weight.dtype==torch.float16: + distance_embs = distance_embs.half() + emb = self.embed_distances(distance_embs) + return self.out(self.gelu(emb)).view(*relative_distances.shape, self.num_heads, self.n_embd//self.num_heads) + +class ProdPENet(nn.Module): + # embed the tuple of two positions [P1,P2] into positional embedding + def __init__(self, n_embd, num_heads): + super().__init__() + self.embed_distances = nn.Linear(5, n_embd) + + self.gelu = NewGELU() + self.out = nn.Linear(n_embd, n_embd) + self.out.weight.data *= 0. + self.out.bias.data *= 0. + self.n_embd = n_embd + self.num_heads = num_heads + + def forward(self, relative_distances,P1,P2): + distance_embs = torch.stack( + [torch.log(1+(relative_distances).clamp(min=0)), + torch.log(1+(-relative_distances).clamp(min=0)), + (relative_distances == 0).float(), + torch.log(1+P1), + torch.log(1+P2)], + dim=-1 + ) # BxTxTx3 + if self.embed_distances.weight.dtype==torch.float16: + distance_embs = distance_embs.half() + emb = self.embed_distances(distance_embs) + return self.out(self.gelu(emb)).view(*relative_distances.shape, self.num_heads, self.n_embd//self.num_heads) + +class sRPE(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py + def __init__(self, n_embd, num_heads, use_rpe_net=False): + """ This module handles the relative positional encoding. + Args: + channels (int): Number of input channels. + num_heads (int): Number of attention heads. + """ + super().__init__() + self.num_heads = num_heads + self.head_dim = n_embd // self.num_heads + self.use_rpe_net = use_rpe_net + if use_rpe_net: + self.rpe_net = sRPENet(n_embd, num_heads) + else: + self.prod_pe_net = ProdPENet(n_embd, num_heads) + # raise NotImplementedError + # self.lookup_table_weight = nn.Parameter( + # torch.zeros(2 * self.beta + 1, + # self.num_heads, + # self.head_dim)) + + def get_R(self, pairwise_distances,P1=None,P2=None): + if self.use_rpe_net: + return self.rpe_net(pairwise_distances) + else: + return self.prod_pe_net(pairwise_distances,P1,P2) + # return self.lookup_table_weight[pairwise_distances] # BxTxTxHx(C/H) + + def forward(self, x, pairwise_distances, mode,P1=None,P2=None): + if mode == "qk": + return self.forward_qk(x, pairwise_distances,P1=P1,P2=P2) + elif mode == "v": + return self.forward_v(x, pairwise_distances,P1=P1,P2=P2) + else: + raise ValueError(f"Unexpected RPE attention mode: {mode}") + + def forward_qk(self, qk, pairwise_distances,P1=None,P2=None): + # qv is either of q or k and has shape BxHxTx(C/H) + # Output shape should be # BxHxTxT + R = self.get_R(pairwise_distances,P1=P1,P2=P2) + if qk.ndim==4: + return torch.einsum( # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf + "bhtf,btshf->bhts", qk, R # BxHxTxT + ) + elif qk.ndim==5: + return torch.einsum( # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf + "bhtsf,btshf->bhts", qk, R # BxHxTxT + ) + + def forward_v(self, attn, pairwise_distances,P1=None,P2=None): + # attn has shape BxHxT1xT2 + # Output shape should be # BxHxT1x(C/H) + R = self.get_R(pairwise_distances,P1=P1,P2=P2) + return torch.einsum( # See Eq. 16ish in https://arxiv.org/pdf/2107.14222.pdf + "bhts,btshf->bhtf", attn, R # BxHxTxT + ) + + def forward_safe_qk(self, x, pairwise_distances): + R = self.get_R(pairwise_distances) + B, T, _, H, F = R.shape + res = x.new_zeros(B, H, T, T) # attn shape + for b in range(B): + for h in range(H): + for i in range(T): + for j in range(T): + res[b, h, i, j] = x[b, h, i].dot(R[b, i, j, h]) + return res + +class sRPEAttention(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py#L42 + def __init__(self, n_embd, num_heads, use_checkpoint=False,use_rpe_net=None, + use_rpe_q=True, use_rpe_k=True, use_rpe_v=True, + ): + super().__init__() + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + self.use_checkpoint = use_checkpoint + + self.qkv = nn.Linear(n_embd, n_embd * 3) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + if use_rpe_q or use_rpe_k or use_rpe_v: + assert use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_q = make_rpe_func() if use_rpe_q else None + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x, attn_mask, frame_indices, attn_weights_list=None): + out, attn = checkpoint(self._forward, (x, attn_mask, frame_indices), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T, C = x.shape + attn_weights_list.append(attn.detach().view(B, -1, T, T).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x, attn_mask, frame_indices): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, C // self.num_heads) + qkv = torch.einsum("BTtHF -> tBHTF", qkv) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + # q, k, v shapes: BxHxTx(C/H) + q *= self.scale + attn = (q @ k.transpose(-2, -1)) # BxDxHxTxT + if self.rpe_q is not None or self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)) # BxTxT + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk") + # relative position on queries + if self.rpe_q is not None: + attn += self.rpe_q(k * self.scale, pairwise_distances, mode="qk") + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, 1, T, T) # BxDxHxTxT + w.masked_fill_((attn_mask==0).all(-1).view(B,1,1,T,1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v") + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T, C) + out = self.proj_out(out) + return out, attn + +class AuxPEAttention(nn.Module): + def __init__(self, n_embd, num_heads, max_len = 100,attn_pdrop=0.0,resid_pdrop = 0.0,aux_vardim=0): + super().__init__() + + self.num_heads = num_heads + self.PE_q = nn.Embedding(max_len, n_embd) + self.PE_k = nn.Embedding(max_len, n_embd) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.aux_vardim = aux_vardim + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+aux_vardim, n_embd) + self.vnet = nn.Linear(n_embd+aux_vardim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + def forward(self, x, aux_x, attn_mask, frame_indices): + + B, T, C = x.shape + if aux_x is not None: + x_aug = torch.cat([x, aux_x], dim=2) + else: + x_aug = x + q,k,v = self.qnet(x).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2),\ + self.knet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2),\ + self.vnet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + + # q, k, v shapes: BxHxTx(C/H) + q = q + self.PE_q(frame_indices).view(B,T,self.num_heads,C//self.num_heads) + k = k + self.PE_k(frame_indices).view(B,T,self.num_heads,C//self.num_heads) + + out = F.scaled_dot_product_attention(q,k,v,attn_mask[:,None].repeat_interleave(self.num_heads,1),dropout=0.0) + out = out.nan_to_num(0.0) + out = out.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + out = self.resid_dropout(self.proj_out(out)) + return out + + + +class sAuxRPEAttention(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py#L42 + def __init__(self, n_embd, num_heads, aux_vardim, use_checkpoint=False, + use_rpe_net=None, use_rpe_q=True, use_rpe_k=True, use_rpe_v=True + ): + super().__init__() + assert aux_vardim>0 + self.aux_vardim = aux_vardim + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + self.use_checkpoint = use_checkpoint + + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+aux_vardim, n_embd) + self.vnet = nn.Linear(n_embd+aux_vardim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + if use_rpe_q or use_rpe_k or use_rpe_v: + assert use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_q = make_rpe_func() if use_rpe_q else None + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x, aux_x, attn_mask, frame_indices, attn_weights_list=None,**kwargs): + out, attn = checkpoint(self._forward, (x, aux_x, attn_mask, frame_indices), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T, C = x.shape + attn_weights_list.append(attn.detach().view(B, -1, T, T).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x, aux_x, attn_mask, frame_indices): + B, T, C = x.shape + x_aug = torch.cat([x, aux_x], dim=2) + + q = self.qnet(x).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + k = self.knet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + v = self.vnet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + # q, k, v shapes: BxHxTx(C/H) + q *= self.scale + attn = (q @ k.transpose(-2, -1)) # BxHxTxT + if self.rpe_q is not None or self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)) # BxTxT + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk") + # relative position on queries + if self.rpe_q is not None: + attn += self.rpe_q(k * self.scale, pairwise_distances, mode="qk") + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, T, T) # BxHxTxT + w.masked_fill_((attn_mask[:,None]==0).all(-1).unsqueeze(-1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v") + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T, C) + out = self.proj_out(out) + return out, attn + +class sAuxRPECrossAttention(nn.Module): + """ RPE cross attention with auxillary variable + + """ + def __init__(self, n_embd, num_heads, edge_dim,aux_edge_func=None, use_checkpoint=False, + use_rpe_net=None, use_rpe_k=True, use_rpe_v=True + ): + super().__init__() + # assert edge_dim% num_heads==0 + self.edge_dim = edge_dim + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + self.use_checkpoint = use_checkpoint + + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+edge_dim, n_embd) + self.vnet = nn.Linear(n_embd+edge_dim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + self.aux_edge_func = aux_edge_func + + if use_rpe_k or use_rpe_v: + assert use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x1, x2, attn_mask, aux_x1,aux_x2, frame_indices1, frame_indices2, edge = None,attn_weights_list=None): + out, attn = checkpoint(self._forward, (x1, x2, attn_mask, aux_x1, aux_x2, frame_indices1, frame_indices2,edge), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T1, C = x1.shape + T2 = x2.shape[1] + attn_weights_list.append(attn.detach().view(B, -1, T1, T2).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x1, x2, attn_mask, aux_x1, aux_x2, frame_indices1, frame_indices2,edge=None): + B, T1, C = x1.shape + T2 = x2.shape[1] + q = self.qnet(x1) + if self.edge_dim>0: + if edge is None: + if self.aux_edge_func is None: + # default to concatenation + edge = torch.cat([aux_x1.unsqueeze(2).repeat_interleave(T2,2), + aux_x2.unsqueeze(1).repeat_interleave(T1,1)],-1) # (B,T1,T2,auxdim1+auxdim2) + else: + edge = self.aux_edge_func(aux_x1,aux_x2) # (B,T1,T2,aux_vardim) + if self.knet.weight.dtype==torch.float16: + edge = edge.half() + aug_x2 = torch.cat([x2.unsqueeze(1).repeat_interleave(T1,1), edge], dim=-1) + else: + aug_x2 = x2.unsqueeze(1).repeat_interleave(T1,1) + + k = self.knet(aug_x2).reshape(B,T1, T2, self.num_heads,C//self.num_heads).permute(0, 3, 1, 2,4) #B,nh,T1,T2,C//nh + v = self.vnet(aug_x2).reshape(B,T1,T2,self.num_heads,C//self.num_heads).permute(0, 3, 1, 2,4) #B,nh,T1,T2,C//nh + + q = (q*self.scale).view(B, T1, self.num_heads, 1, C // self.num_heads).transpose(1, 2).repeat_interleave(T2,3) # (B, nh, T1, T2, hs) + attn = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + if self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices1.unsqueeze(-1) - frame_indices2.unsqueeze(-2)) # BxT1xT2 + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk",P1=frame_indices1.unsqueeze(-1).expand(*pairwise_distances.shape),\ + P2=frame_indices2.unsqueeze(-2).expand(*pairwise_distances.shape)) + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, T1, T2) # BxHxTxT + w.masked_fill_((attn_mask[:,None]==0).all(-1).unsqueeze(-1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v if v.ndim==4 else (attn.unsqueeze(-1)*v).sum(-2) + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v",P1=frame_indices1.unsqueeze(-1).expand(*pairwise_distances.shape),\ + P2=frame_indices2.unsqueeze(-2).expand(*pairwise_distances.shape)) + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T1, C) + out = self.proj_out(out) + return out, attn + +def testAux(): + + net = sAuxRPEAttention(n_embd=32, num_heads=4, aux_vardim=4,use_rpe_net=True) + B,C,T = 3,32,10 + x = torch.randn(B,T,C) + aux_x = torch.randn(B,T,4) + frame_indices = torch.arange(T).unsqueeze(0).repeat(B,1) + out = net(x, aux_x, None, frame_indices) + +def testAuxCross(): + net = sAuxRPECrossAttention(n_embd=32, num_heads=4, edge_dim=4,use_rpe_net=True) + B,C,T1,T2 = 3,32,10,5 + x1 = torch.randn(B,T1,C) + x2 = torch.randn(B,T2,C) + aux_x1 = torch.randn(B,T1,4) + aux_x2 = torch.randn(B,T2,4) + frame_indices1 = torch.arange(T1).unsqueeze(0).repeat(B,1) + frame_indices2 = torch.arange(T2).unsqueeze(0).repeat(B,1)+T1 + out = net(x1,x2, None, aux_x1, aux_x2, frame_indices1, frame_indices2) + print("123") + +if __name__ == "__main__": + testAuxCross() \ No newline at end of file diff --git a/diffstack/models/Transformer.py b/diffstack/models/Transformer.py new file mode 100644 index 0000000..faffc5b --- /dev/null +++ b/diffstack/models/Transformer.py @@ -0,0 +1,861 @@ +from logging import raiseExceptions +import numpy as np + +import torch +import math, copy +from typing import Dict +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import diffstack.utils.tensor_utils as TensorUtils + + +def clones(module, n): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +class FactorizedEncoderDecoder(nn.Module): + """ + A encoder-decoder transformer model with Factorized encoder and decoder + """ + + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator, src2posfun): + """ + Args: + encoder: FactorizedEncoder + decoder: FactorizedDecoder + src_embed: source embedding network + tgt_embed: target embedding network + generator: network used to generate output from target + src2posfun: extract positional info from the src + """ + super(FactorizedEncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + self.src2posfun = src2posfun + + def src2pos(self, src, dyn_type): + "extract positional info from src for all datatypes, e.g., for vehicles, the first two dimensions are x and y" + + pos = torch.zeros([*src.shape[:-1], 2]).to(src.device) + for dt, fun in self.src2posfun.items(): + pos += fun(src) * (dyn_type == dt).view([*(dyn_type.shape), 1, 1]) + + return pos + + def forward( + self, + src, + tgt, + src_mask, + tgt_mask, + tgt_mask_agent, + dyn_type, + map_emb=None, + ): + "Take in and process masked src and target sequences." + src_pos = self.src2pos(src, dyn_type) + "for decoders, we only use position at the last time step of the src" + return self.decode( + self.encode(src, src_mask, src_pos, map_emb), + src_mask, + tgt, + tgt_mask, + tgt_mask_agent, + src_pos[:, :, -1:], + ) + + def encode(self, src, src_mask, src_pos, map_emb): + return self.encoder(self.src_embed(src), src_mask, src_pos, map_emb) + + def decode(self, memory, src_mask, tgt, tgt_mask, tgt_mask_agent, pos): + + return self.decoder( + self.tgt_embed(tgt), + memory, + src_mask, + tgt_mask, + tgt_mask_agent, + pos, + ) + + +class DynamicGenerator(nn.Module): + "Incorporating dynamics to the generator to generate dynamically feasible output, not used yet" + + def __init__(self, d_model, dt, dyns, state2feature, feature2state): + super(DynamicGenerator, self).__init__() + self.dyns = dyns + self.proj = dict() + self.dt = dt + self.state2feature = state2feature + self.feature2state = feature2state + for dyn in self.dyns: + self.proj[dyn.type()] = nn.Linear(d_model, dyn.udim) + + def forward(self, x, tgt, type_index): + Nagent = tgt.shape[0] + tgt_next = [None] * Nagent + for dyn in self.dyns: + index = type_index[dyn.type()] + state = self.feature2state[dyn.name](tgt[index]) + input = self.proj[dyn.type()](x) + state_next = dyn.step(state, input, self.dt) + x_next_raw = self.state2feature[dyn.name](state_next) + for i in range(len(index)): + tgt_next[index[i]] = x_next_raw[i] + return torch.stack(tgt_next, dim=0) + + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask, mask1=None): + "Pass the input (and mask) through each layer in turn." + for i, layer in enumerate(self.layers): + if i == 0: + x = layer(x, mask) + else: + if mask1 is None: + x = layer(x, mask) + else: + x = layer(x, mask1) + return self.norm(x) + + +class FactorizedEncoder(nn.Module): + def __init__(self, temporal_enc, agent_enc, temporal_pe, XY_pe, N_layer=1): + """ + Factorized encoder and agent axis + Args: + temporal_enc: encoder with attention over temporal axis + agent_enc: encoder with attention over agent axis + temporal_pe: positional encoding over time + XY_pe: positional encoding over XY coordinates + """ + super(FactorizedEncoder, self).__init__() + self.N_layer = N_layer + self.temporal_encs = clones(temporal_enc, N_layer) + self.agent_encs = clones(agent_enc, N_layer) + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + + def forward(self, x, src_mask, src_pos, map_emb): + """Pass the input (and mask) through each layer in turn. + Args: + x:[B,Num_agent,T,d_model] + src_mask:[B,Num_agent,T] + src_pos:[B,Num_agent,T,2] + map_emb: [B,Num_agent,1,map_emb_dim] output of the CNN ROI map encoder + Returns: + embedding of size [B,Num_agent,T,d_model] + """ + + if map_emb.ndim == 3: + map_emb = map_emb.unsqueeze(2).repeat(1, 1, x.size(2), 1) + x = ( + torch.cat( + ( + x, + self.XY_pe(x, src_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + map_emb, + ), + dim=-1, + ) + * src_mask.unsqueeze(-1) + ) + for i in range(self.N_layer): + x = self.agent_encs[i](x, src_mask) + x = self.temporal_encs[i](x, src_mask) + return x + + +class StaticEncoder(nn.Module): + def __init__(self, agent_enc, XY_pe, N_layer=1): + """ + Factorized encoder and agent axis + Args: + temporal_enc: encoder with attention over temporal axis + agent_enc: encoder with attention over agent axis + temporal_pe: positional encoding over time + XY_pe: positional encoding over XY coordinates + """ + super(StaticEncoder, self).__init__() + self.N_layer = N_layer + self.agent_encs = clones(agent_enc, N_layer) + self.XY_pe = XY_pe + + def forward(self, x, src_mask, src_pos, map_emb=None): + """Pass the input (and mask) through each layer in turn. + Args: + x:[B,Num_agent,T,d_model] + src_mask:[B,Num_agent,T] + src_pos:[B,Num_agent,T,2] + map_emb: [B,Num_agent,1,map_emb_dim] output of the CNN ROI map encoder + Returns: + embedding of size [B,Num_agent,T,d_model] + """ + inputs = [x, self.XY_pe(x, src_pos)] + if map_emb is not None: + inputs.append(map_emb) + + x = ( + torch.cat( + ( + inputs + ), + dim=-1, + ) + * src_mask.unsqueeze(-1) + ) + for i in range(self.N_layer): + x = self.agent_encs[i](x, src_mask) + return x + + +class LayerNorm(nn.Module): + "Construct a layernorm module" + + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward" + + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "self attention followed by feedforward, residual and batch norm in between layers" + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask): + "cross attention to the embedding generated by the encoder" + for layer in self.layers: + x = layer(x, memory, src_mask, tgt_mask) + return self.norm(x) + + +class SummaryModel(nn.Module): + """ + map the scene information to attributes that summarizes the scene + """ + + def __init__(self, encoder, decoder, src_embed, src2posfun): + super(SummaryModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.src2posfun = src2posfun + + def src2pos(self, src, dyn_type): + "extract positional info from src for all datatypes, e.g., for vehicles, the first two dimensions are x and y" + + pos = torch.zeros([*src.shape[:-1], 2]).to(src.device) + for dt, fun in self.src2posfun.items(): + pos += fun(src) * (dyn_type == dt).view([*(dyn_type.shape), 1, 1]) + + return pos + + def forward( + self, + src, + src_mask, + dyn_type, + map_emb, + ): + "Take in and process masked src and target sequences." + src_pos = self.src2pos(src, dyn_type) + return self.decode( + self.encode(src, src_mask, src_pos, map_emb), + src_mask, + ) + + def encode(self, src, src_mask, src_pos, map_emb): + return self.encoder(self.src_embed(src), src_mask, src_pos, map_emb) + + def decode(self, memory, src_mask): + return self.decoder(memory, src_mask) + + +class SummaryDecoder(nn.Module): + """ + Map the encoded tensor to a description of the whole scene, e.g., the likelihood of certain modes + """ + + def __init__( + self, temporal_attn, agent_attn, ff, emb_dim, output_dim, static=False + ): + super(SummaryDecoder, self).__init__() + self.temporal_attn = temporal_attn + self.agent_attn = agent_attn + self.ff = ff + self.output_dim = output_dim + self.static = static + self.MLP = nn.Sequential(nn.Linear(emb_dim, output_dim), nn.Sigmoid()) + + def forward(self, x, mask): + x = self.agent_attn(x, x, x, mask) + x = self.ff(torch.max(x, dim=-3)[0]).unsqueeze(1) + if not self.static: + x = self.temporal_attn(x, x, x) + x = torch.max(x, dim=-2)[0].squeeze(1) + x = self.MLP(x) + return x + + +class FactorizedDecoder(nn.Module): + """ + Args: + temporal_dec: decoder with attention over temporal axis + agent_enc: decoder with attention over agent axis + temporal_pe: positional encoding for time axis + XY_pe: positional encoding for XY axis + """ + + def __init__( + self, + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + N_layer_enc=1, + N_layer_dec=1, + ): + super(FactorizedDecoder, self).__init__() + self.temporal_dec = clones(temporal_dec, N_layer_dec) + self.agent_enc = clones(agent_enc, N_layer_enc) + self.temporal_enc = clones(temporal_enc, N_layer_enc) + self.N_layer_enc = N_layer_enc + self.N_layer_dec = N_layer_dec + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + + def forward(self, x, memory, src_mask, tgt_mask, tgt_mask_agent, pos): + """ + Pass the input (and mask) through each layer in turn. + Args: + x (torch.tensor)): [batch,Num_agent,T_tgt,d_model] + memory (torch.tensor): [batch,Num_agent,T_src,d_model] + src_mask (torch.tensor): [batch,Num_agent,T_src] + tgt_mask (torch.tensor): [batch,Num_agent,T_tgt] + tgt_mask_agent (torch.tensor): [batch,Num_agent,T_tgt] + pos (torch.tensor): [batch,Num_agent,1,2] + + Returns: + torch.tensor: [batch,Num_agent,T_tgt,d_model] + """ + T = x.size(-2) + tgt_pos = pos.repeat([1, 1, T, 1]) + + x = ( + torch.cat( + ( + x, + self.XY_pe(x, tgt_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + ), + dim=-1, + ) + * tgt_mask_agent.unsqueeze(-1) + ) + + for i in range(self.N_layer_dec): + x = self.temporal_dec[i](x, memory, src_mask, tgt_mask) + prob = torch.ones(x.shape[0]).to(x.device) + return x * tgt_mask_agent.unsqueeze(-1), prob + + +class MultimodalFactorizedDecoder(nn.Module): + """ + Args: + temporal_dec: decoder with attention over temporal axis + agent_enc: decoder with attention over agent axis + temporal_pe: positional encoding for time axis + XY_pe: positional encoding for XY axis + """ + + def __init__( + self, + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + M, + summary_dec, + N_layer_enc=1, + N_layer_dec=1, + ): + super(MultimodalFactorizedDecoder, self).__init__() + self.M = M + self.temporal_dec = clones(temporal_dec, N_layer_dec) + self.agent_enc = clones(agent_enc, N_layer_enc) + self.temporal_enc = clones(temporal_enc, N_layer_enc) + self.N_layer_enc = N_layer_enc + self.N_layer_dec = N_layer_dec + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + self.summary_dec = summary_dec + + def forward(self, x, memory, src_mask, tgt_mask, tgt_mask_agent, pos): + """ + Pass the input (and mask) through each layer in turn. + Args: + x (torch.tensor)): [batch,Num_agent,T_tgt,d_model] + memory (torch.tensor): [batch,Num_agent,T_src,d_model] + src_mask (torch.tensor): [batch,Num_agent,T_src] + tgt_mask (torch.tensor): [batch,Num_agent,T_tgt] + tgt_mask_agent (torch.tensor): [batch,Num_agent,T_tgt] + pos (torch.tensor): [batch,Num_agent,1,2] + + Returns: + torch.tensor: [batch,Num_agent,T_tgt,d_model] + """ + T = x.size(-2) + tgt_pos = pos.repeat([1, 1, T, 1]) + + x = ( + torch.cat( + ( + x, + self.XY_pe(x, tgt_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + ), + dim=-1, + ) + * tgt_mask_agent.unsqueeze(-1) + ) + + # adding one-hot encoding of the modes + modes_enc = ( + F.one_hot(torch.arange(0, self.M)) + .view(1, self.M, 1, 1, self.M) + .repeat(x.size(0), 1, x.size(1), x.size(2), 1) + ).to(x.device) + + x = torch.cat((x.unsqueeze(1).repeat(1, self.M, 1, 1, 1), modes_enc), dim=-1) + + memory_M = memory.unsqueeze(1).repeat(1, self.M, 1, 1, 1) + src_mask_M = src_mask.unsqueeze(1).repeat(1, self.M, 1, 1) + tgt_mask_M = tgt_mask.unsqueeze(1).repeat(1, self.M, 1, 1, 1) + tgt_mask_agent_M = tgt_mask_agent.unsqueeze(1).repeat(1, self.M, 1, 1) + for i in range(self.N_layer_enc): + x = self.agent_enc[i](x, tgt_mask_agent_M) + x = self.temporal_enc[i](x, tgt_mask_M) + for i in range(self.N_layer_dec): + x = self.temporal_dec[i]( + x, + memory_M, + src_mask_M, + tgt_mask_M, + ) + + prob = self.summary_dec(x, tgt_mask_agent_M).squeeze(-1) + prob = F.softmax(prob, dim=-1) + return x * tgt_mask_agent_M.unsqueeze(-1), prob + + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask): + "self attention followed by cross attention with the encoder output" + + m = memory + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) + return self.sublayer[2](x, self.feed_forward) + + +def subsequent_mask(size): + "Mask out subsequent positions." + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") + return torch.from_numpy(subsequent_mask) == 0 + + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1, pooling_dim=None): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.pooling_dim = pooling_dim + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if self.pooling_dim is None: + pooling_dim = -2 + else: + pooling_dim = self.pooling_dim + + if mask is not None: + # Same mask applied to all h heads. + if mask.ndim == query.ndim - 1: + mask = mask.view([*mask.shape, 1, 1]).transpose(-1, pooling_dim - 1) + elif mask.ndim == query.ndim: + mask = mask.unsqueeze(-2).transpose(-2, pooling_dim - 1) + else: + raise Exception("mask dimension mismatch") + + # 1) Do all the linear projections in batch from d_model => h x d_k + + query, key, value = [ + l(x).view(*x.shape[:-1], self.h, self.d_k) + for l, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention( + query.transpose(-2, pooling_dim - 1), + key.transpose(-2, pooling_dim - 1), + value.transpose(-2, pooling_dim - 1), + mask, + dropout=self.dropout, + ) + + x = x.transpose(-2, pooling_dim - 1).contiguous() + x = x.view(*x.shape[:-2], self.h * self.d_k) + + # 3) "Concat" using a view and apply a final linear. + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + + def __init__(self, dim, dropout, max_len=5000, flipped=False): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.flipped = flipped + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + if self.flipped: + position = -position.flip(dims=[0]) + div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + pe_shape = [1] * (x.ndim - 2) + list(x.shape[-2:-1]) + [self.dim] + if self.flipped: + return self.dropout( + Variable(self.pe[:, -x.size(-2) :].view(pe_shape), requires_grad=False) + ) + else: + return self.dropout( + Variable(self.pe[:, : x.size(-2)].view(pe_shape), requires_grad=False) + ) + + +class PositionalEncodingNd(nn.Module): + "extension of the PE function, works for N dimensional position input" + + def __init__(self, dim, dropout, step_size=[1]): + """ + step_size: scale of each dimension, pos/step_size = phase for the sinusoidal PE + """ + super(PositionalEncodingNd, self).__init__() + assert dim % 2 == 0 + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.step_size = step_size + self.D = len(step_size) + self.pe = list() + + # Compute the positional encodings once in log space. + self.div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)) + + def forward(self, x, pos): + rep_size = [1] * (x.ndim) + rep_size[-1] = int(self.dim / 2) + pe_shape = [*x.shape[:-1], self.dim] + for i in range(self.D): + pe = torch.zeros(pe_shape).to(x.device) + + pe[..., 0::2] = torch.sin( + pos[..., i : i + 1].repeat(*rep_size) + / self.step_size[i] + * self.div_term.to(x.device) + ) + pe[..., 1::2] = torch.sin( + pos[..., i : i + 1].repeat(*rep_size) + / self.step_size[i] + * self.div_term.to(x.device) + ) + return self.dropout(Variable(pe, requires_grad=False)) + + +def make_transformer_model( + src_dim, + tgt_dim, + out_dim, + dyn_list, + N_t=6, + N_a=3, + d_model=384, + XY_pe_dim=64, + temporal_pe_dim=64, + map_emb_dim=128, + d_ff=2048, + head=8, + dropout=0.1, + step_size=[0.1, 0.1], + N_layer_enc=1, + N_layer_tgt_enc=1, + N_layer_tgt_dec=1, + M=1, + use_GAN=False, + GAN_static=True, + N_layer_enc_discr=1, +): + "first generate the building blocks, attn networks, encoders, decoders, PEs and Feedforward nets" + c = copy.deepcopy + temporal_attn = MultiHeadedAttention(head, d_model) + agent_attn = MultiHeadedAttention(head, d_model, pooling_dim=-3) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + temporal_pe = PositionalEncoding(temporal_pe_dim, dropout) + temporal_pe_flip = PositionalEncoding(temporal_pe_dim, dropout, flipped=True) + XY_pe = PositionalEncodingNd(XY_pe_dim, dropout, step_size=step_size) + temporal_enc = Encoder(EncoderLayer(d_model, c(temporal_attn), c(ff), dropout), N_t) + agent_enc = Encoder(EncoderLayer(d_model, c(agent_attn), c(ff), dropout), N_a) + + src_emb = nn.Linear(src_dim, d_model - XY_pe_dim - temporal_pe_dim - map_emb_dim) + if M == 1: + tgt_emb = nn.Linear(tgt_dim, d_model - XY_pe_dim - temporal_pe_dim) + else: + tgt_emb = nn.Linear(tgt_dim, d_model - XY_pe_dim - temporal_pe_dim - M) + generator = nn.Linear(d_model, out_dim) + + temporal_dec = Decoder( + DecoderLayer(d_model, c(temporal_attn), c(temporal_attn), c(ff), dropout), N_t + ) + "gather src2posfun from all agent types" + src2posfun = {D.type(): D.state2pos for D in dyn_list} + + Factorized_Encoder = FactorizedEncoder( + c(temporal_enc), c(agent_enc), temporal_pe_flip, XY_pe, N_layer_enc + ) + if M == 1: + Factorized_Decoder = FactorizedDecoder( + c(temporal_dec), + c(agent_enc), + c(temporal_enc), + temporal_pe, + XY_pe, + N_layer_tgt_enc, + N_layer_tgt_dec, + ) + else: + mode_summary_dec = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1 + ) + Factorized_Decoder = MultimodalFactorizedDecoder( + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + M, + mode_summary_dec, + N_layer_enc=1, + N_layer_dec=1, + ) + Factorized_Encoder = FactorizedEncoder( + c(temporal_enc), c(agent_enc), temporal_pe_flip, XY_pe, N_layer_enc + ) + if use_GAN: + if GAN_static: + Summary_Encoder = StaticEncoder( + c(agent_enc), + XY_pe, + N_layer_enc_discr, + ) + Summary_Decoder = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1, static=True + ) + static_src_emb = nn.Linear(src_dim, d_model - XY_pe_dim - map_emb_dim) + Summary_Model = SummaryModel( + Summary_Encoder, + Summary_Decoder, + c(static_src_emb), + src2posfun, + ) + else: + Summary_Encoder = Summary_Encoder = FactorizedEncoder( + c(temporal_enc), + c(agent_enc), + temporal_pe_flip, + XY_pe, + N_layer_enc_discr, + ) + Summary_Decoder = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1, static=True + ) + Summary_Model = SummaryModel( + Summary_Encoder, + Summary_Decoder, + c(src_emb), + src2posfun, + ) + + else: + Summary_Model = None + "use a simple nn.Linear as the generator as our output is continuous" + + Transformer_Model = FactorizedEncoderDecoder( + Factorized_Encoder, + Factorized_Decoder, + c(src_emb), + c(tgt_emb), + c(generator), + src2posfun, + ) + + return Transformer_Model, Summary_Model + + +class SimpleTransformer(nn.Module): + def __init__( + self, + src_dim, + N_a=3, + d_model=384, + XY_pe_dim=64, + d_ff=2048, + head=8, + dropout=0.1, + step_size=[0.1, 0.1], + ): + super(SimpleTransformer, self).__init__() + c = copy.deepcopy + agent_attn = MultiHeadedAttention(head, d_model, pooling_dim=-3) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + XY_pe = PositionalEncodingNd(XY_pe_dim, dropout, step_size=step_size) + self.agent_enc = StaticEncoder(EncoderLayer(d_model, c(agent_attn), c(ff), dropout), XY_pe, N_a) + self.pre_emb = nn.Linear(src_dim, d_model - XY_pe_dim) + self.post_emb = nn.Linear(d_model, src_dim) + + def forward(self, feats, avails, pos): + x = self.pre_emb(feats) + x = self.agent_enc(x, avails, pos) + return self.post_emb(x) + + +class simplelinear(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim=[64, 32]): + super(simplelinear, self).__init__() + self.hidden_dim = hidden_dim + self.hidden_layers = len(hidden_dim) + self.fhidden = nn.ModuleList() + + for i in range(1, self.hidden_layers): + self.fhidden.append(nn.Linear(hidden_dim[i - 1], hidden_dim[i])) + + self.f1 = nn.Linear(input_dim, hidden_dim[0]) + self.f2 = nn.Linear(hidden_dim[-1], output_dim) + + def forward(self, x): + hidden = self.f1(x) + for i in range(1, self.hidden_layers): + hidden = self.fhidden[i - 1](F.relu(hidden)) + return self.f2(F.relu(hidden)) diff --git a/diffstack/models/TypeTransformer.py b/diffstack/models/TypeTransformer.py new file mode 100644 index 0000000..d517f4f --- /dev/null +++ b/diffstack/models/TypeTransformer.py @@ -0,0 +1,791 @@ +""" +Transformer that accepts multiple types with separate attention +Adopted from the minGPT project https://github.com/karpathy/minGPT + + +""" + +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ----------------------------------------------------------------------------- + + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. + """ + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, mask=None): + # mask: (B,T,T) + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class AuxSelfAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. + """ + + def __init__( + self, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + PE_len=None, + ): + super().__init__() + assert n_embd % n_head == 0 + # assert edge_dim % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.q_net = nn.Linear(n_embd, n_embd) + self.k_net = nn.Linear(n_embd + edge_dim, n_embd) + self.v_net = nn.Linear(n_embd + edge_dim, n_embd) + self.aux_edge_func = aux_edge_func + + self.PE_len = PE_len + if PE_len is not None: + self.PE_q = nn.Embedding(PE_len, n_embd) + self.PE_k = nn.Embedding(PE_len, n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, aux_x, mask=None, edge=None, frame_indices=None): + # mask: (B,T,T) + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.q_net(x) + if edge is not None or self.aux_edge_func is not None: + T = aux_x.size(1) + if edge is None: + edge = self.aux_edge_func(aux_x, aux_x) # (B,T,T,aux_vardim) + if self.k_net.weight.dtype == torch.float16: + edge = edge.half() + aug_x = torch.cat([x.unsqueeze(2).repeat_interleave(T, 2), edge], dim=-1) + k = self.k_net(aug_x) + v = self.v_net(aug_x) + k = k.view(B, T, T, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T, T, hs) + q = ( + q.view(B, T, self.n_head, 1, C // self.n_head) + .transpose(1, 2) + .repeat_interleave(T, 3) + ) # (B, nh, T, T, hs) + v = v.view(B, T, T, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T, T, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices).view( + B, T, self.n_head, 1, C // self.n_head + ).transpose(1, 2).repeat_interleave(T, 3) + k = k + self.PE_k(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2).unsqueeze(2).repeat_interleave(T, 2) + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + else: + aug_x = torch.cat([x, aux_x], dim=-1) + k = self.k_net(aug_x) + v = self.v_net(aug_x) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2) + k = k + self.PE_k(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2) + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + # if v.ndim==4, (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) elif v.ndim==5, (B, nh, T, T) x (B, nh, T, T, hs) -> (B, nh, T, hs) + y = att @ v if v.ndim == 4 else (att.unsqueeze(-1) * v).sum(-2) + + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class AuxCrossAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. + """ + + def __init__( + self, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + PE_len=None, + ): + super().__init__() + assert n_embd % n_head == 0 + # assert edge_dim % n_head == 0 + self.edge_dim = edge_dim + # key, query, value projections for all heads, but in a batch + self.q_net = nn.Linear(n_embd, n_embd) + self.k_net = nn.Linear(n_embd + edge_dim, n_embd) + self.v_net = nn.Linear(n_embd + edge_dim, n_embd) + self.PE_len = PE_len + if PE_len is not None: + self.PE_q = nn.Embedding(PE_len, n_embd) + self.PE_k = nn.Embedding(PE_len, n_embd) + self.aux_edge_func = aux_edge_func + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward( + self, + x1, + x2, + mask, + aux_x1, + aux_x2, + frame_indices1=None, + frame_indices2=None, + edge=None, + ): + # mask: (B,T,T) + ( + B, + T1, + C, + ) = x1.size() # batch size, sequence length, embedding dimensionality (n_embd) + T2 = x2.size(1) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.q_net(x1) + if self.edge_dim > 0: + if edge is None: + if self.aux_edge_func is None: + # default to concatenation + edge = torch.cat( + [ + aux_x1.unsqueeze(2).repeat_interleave(T2, 2), + aux_x2.unsqueeze(1).repeat_interleave(T1, 1), + ], + -1, + ) # (B,T1,T2,auxdim1+auxdim2) + else: + edge = self.aux_edge_func(aux_x1, aux_x2) # (B,T1,T2,aux_vardim) + if self.k_net.weight.dtype == torch.float16: + edge = edge.half() + aug_x2 = torch.cat([x2.unsqueeze(1).repeat_interleave(T1, 1), edge], dim=-1) + else: + aug_x2 = x2.unsqueeze(1).repeat_interleave(T1, 1) + + k = self.k_net(aug_x2) + v = self.v_net(aug_x2) + k = k.view(B, T1, T2, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T1, T2, hs) + q = ( + q.view(B, T1, self.n_head, 1, C // self.n_head) + .transpose(1, 2) + .repeat_interleave(T2, 3) + ) # (B, nh, T1, T2, hs) + v = v.view(B, T1, T2, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T1, T2, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices1).view( + B, T1, self.n_head, 1, C // self.n_head + ).transpose(1, 2).repeat_interleave(T2, 3) + k = k + self.PE_k(frame_indices2).view( + B, T2, self.n_head, C // self.n_head + ).transpose(1, 2).unsqueeze(2).repeat_interleave(T1, 2) + # # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + # if v.ndim==4, (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) elif v.ndim==5, (B, nh, T, T) x (B, nh, T, T, hs) -> (B, nh, T, hs) + y = att @ v if v.ndim == 4 else (att.unsqueeze(-1) * v).sum(-2) + y = ( + y.transpose(1, 2).contiguous().view(B, T1, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class CrossAttention(nn.Module): + """ + A vanilla multi-head cross-attention layer with a projection at the end. + """ + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_q = nn.Linear(n_embd, n_embd) + self.c_k = nn.Linear(n_embd, n_embd) + self.c_v = nn.Linear(n_embd, n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_pdrop = attn_pdrop + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, z, mask=None): + # x: (B, T_in, C), query + # z: (B, T_m, C), memory + # mask: (B, T_in, T_m), mask for memory + ( + B, + T_in, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + T_m = z.size(1) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.c_q(x) + k, v = self.c_k(z), self.c_v(z) + + k = k.view(B, T_m, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_m, hs) + q = q.view(B, T_in, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_in, hs) + v = v.view(B, T_m, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_m, hs) + + # cross-attention; Cross-attend: (B, nh, T_in, hs) x (B, nh, hs, T_m) -> (B, nh, T_in, T_m) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T_in, T_m) x (B, nh, T_m, hs) -> (B, nh, T_in, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T_in, C) + ) # re-assemble all head outputs side by side + + y = self.resid_dropout(self.c_proj(y)) + + return y + + +class TypeSelfAttention(nn.Module): + """a composite attention block supporting multiple attention types""" + + def __init__( + self, + ntype, + n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + ): + super().__init__() + self.ntype = ntype + self.edge_dim = edge_dim + if edge_dim > 0: + self.attn = nn.ModuleList( + [ + AuxSelfAttention( + n_embd, n_head, edge_dim, aux_edge_func, attn_pdrop, resid_pdrop + ) + for _ in range(self.ntype) + ] + ) + else: + self.attn = nn.ModuleList( + [ + SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + + def forward(self, x, aux_x, masks, edge=None): + if self.ntype == 1 and isinstance(masks, torch.Tensor): + masks = [masks] + assert len(masks) == self.ntype + if edge is not None: + if isinstance(edge, list): + assert len(edge) == self.ntype + elif isinstance(edge, torch.Tensor): + edge = [edge] * self.ntype + else: + edge = [None] * self.ntype + resid = torch.zeros_like(x) + if self.edge_dim > 0: + assert aux_x is not None + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x, aux_x, mask, edge=edge[i]) + else: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x, mask, edge=edge[i]) + return resid + + +class TypeCrossAttention(nn.Module): + """a composite attention block supporting multiple attention types""" + + def __init__( + self, + ntype, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + ): + super().__init__() + self.ntype = ntype + self.edge_dim = edge_dim + if edge_dim > 0: + self.attn = nn.ModuleList( + [ + AuxCrossAttention( + n_embd, n_head, edge_dim, aux_edge_func, attn_pdrop, resid_pdrop + ) + for _ in range(self.ntype) + ] + ) + else: + self.attn = nn.ModuleList( + [ + CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + + def forward(self, x1, x2, masks, aux_x1=None, aux_x2=None, edge=None): + if self.ntype == 1 and isinstance(masks, torch.Tensor): + masks = [masks] + assert len(masks) == self.ntype + if edge is not None: + if isinstance(edge, list): + assert len(edge) == self.ntype + elif isinstance(edge, torch.Tensor): + edge = [edge] * self.ntype + else: + edge = [None] * self.ntype + resid = torch.zeros_like(x1) + if self.edge_dim > 0: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x1, x2, mask, aux_x1, aux_x2, edge=edge[i]) + else: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x1, x2, mask, edge=edge[i]) + return resid + + +class EncBlock(nn.Module): + """an unassuming Transformer encoder block""" + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, mask=None): + x = x + self.attn(self.ln_1(x), mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class DecBlock(nn.Module): + """an unassuming Transformer decoder block""" + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, z, mask=None): + x = x + self.attn(self.ln_1(x), z, mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TypeEncBlock(nn.Module): + """an unassuming Transformer encoder block supporting multiple attention types""" + + def __init__(self, n_embd, ntype, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.ntype = ntype + self.attn = nn.ModuleList( + [ + SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, masks): + assert len(masks) == self.ntype + for i, mask in enumerate(masks): + x = x + self.attn[i](self.ln_1(x), mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TypeDecBlock(nn.Module): + """an unassuming Transformer decoder block supporting multiple attention types""" + + def __init__(self, n_embd, ntype, n_head, attn_pdrop, resid_pdrop): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.ntype = ntype + self.attn = nn.ModuleList( + [ + CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, z, masks): + assert len(masks) == self.ntype + for i, mask in enumerate(masks): + x = x + self.attn[i](self.ln_1(x), z, mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + n_embd, + ntype, + n_head, + attn_pdrop, + resid_pdrop, + nblock, + indim=None, + outdim=None, + ): + super().__init__() + self.nblock = nblock + self.indim = indim + self.outdim = outdim + self.in_embd = ( + nn.Linear(self.indim, n_embd) + if self.indim is not None + else torch.nn.Identity() + ) + self.out_proj = ( + nn.Linear(n_embd, self.outdim) + if self.outdim is not None + else torch.nn.Identity() + ) + self.ntype = ntype + if self.ntype == 1: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): EncBlock( + n_embd, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + else: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): TypeEncBlock( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + + def forward(self, x, mask=None): + x = self.in_embd(x) + for i in range(self.nblock): + x = self.blocks["block_{}".format(i)](x, mask) + x = self.out_proj(x) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + n_embd, + ntype, + n_head, + attn_pdrop, + resid_pdrop, + nblock, + indim=None, + outdim=None, + ): + super().__init__() + self.nblock = nblock + self.indim = indim + self.outdim = outdim + self.in_embd = ( + nn.Linear(self.indim, n_embd) + if self.indim is not None + else torch.nn.Identity() + ) + self.out_proj = ( + nn.Linear(n_embd, self.outdim) + if self.outdim is not None + else torch.nn.Identity() + ) + self.ntype = ntype + if self.ntype == 1: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): DecBlock( + n_embd, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + else: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): TypeDecBlock( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + + def forward(self, x, z, mask=None): + x = self.in_embd(x) + for i in range(self.nblock): + x = self.blocks["block_{}".format(i)](x, z, mask) + x = self.out_proj(x) + return x + + +def test(): + n_embd = 256 + ntype = 3 + n_head = 8 + attn_pdrop = 0.1 + resid_pdrop = 0.1 + enc = TransformerEncoder( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop, nblock=2, indim=15 + ).cuda() + dec = TransformerDecoder( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop, nblock=2, indim=15, outdim=40 + ).cuda() + b = 2 + T_in = 10 + T_m = 20 + xin = torch.randn(b, T_in, 15).cuda() + + type = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2]).cuda() + type_flag = [(type == i)[None].repeat_interleave(b, 0) for i in range(3)] + enc_masks = [fg[..., None] * fg[..., None, :] for fg in type_flag] + + mask1 = torch.ones(2, T_in, T_m).bool().cuda() + mask2 = torch.tril(mask1) + mask3 = torch.triu(mask1) + dec_masks = [mask1, mask2, mask3] + mem = enc(xin, mask=enc_masks) + mem = torch.repeat_interleave(mem, 2, 1) + x1 = dec(xin, mem, [mask1, mask2, mask3]) + print("done") + + +def test_edge_func(): + n_embd = 256 + n_head = 8 + aux_vardim = 16 + aux_edge_func = lambda x, y: x.unsqueeze(2) * y.unsqueeze(1) + + attn = AuxCrossAttention(n_embd, n_head, aux_vardim, aux_edge_func) + b = 2 + N1 = 3 + N2 = 4 + x1 = torch.randn([b, N1, n_embd]) + aux_x1 = torch.randn([b, N1, aux_vardim]) + + x2 = torch.randn([b, N2, n_embd]) + aux_x2 = torch.randn([b, N2, aux_vardim]) + xx = attn(x1, x2, None, aux_x1, aux_x2) + print("123") + + +if __name__ == "__main__": + test_edge_func() diff --git a/diffstack/models/__init__.py b/diffstack/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/models/agentformer.py b/diffstack/models/agentformer.py new file mode 100644 index 0000000..e46f859 --- /dev/null +++ b/diffstack/models/agentformer.py @@ -0,0 +1,3324 @@ +import torch +from collections import OrderedDict +from dataclasses import asdict + +from diffstack import dynamics + +torch.manual_seed(0) +torch.cuda.manual_seed_all(0) + +import numpy as np +from torch import nn +from torch.nn import functional as F +from collections import defaultdict +from diffstack.utils.model_utils import ( + AFMLP, + Normal, + Categorical, + initialize_weights, + rotation_2d_torch, + ExpParamAnnealer, +) +from .agentformer_lib import ( + AgentFormerEncoderLayer, + AgentFormerDecoderLayer, + AgentFormerDecoder, + AgentFormerEncoder, +) +from diffstack.models.agentformer_lib import * +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.loss_utils import MultiModal_trajectory_loss +from diffstack.models import base_models +from diffstack.utils.metrics import ( + DynOrnsteinUhlenbeckPerturbation, +) +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.loss_utils import ( + trajectory_loss, + MultiModal_trajectory_loss, + goal_reaching_loss, + collision_loss, + collision_loss_masked, + log_normal_mixture, + NLL_GMM_loss, + compute_pred_loss, + diversity_score, +) +from diffstack.utils.dist_utils import MAGaussian, MADynGaussian, MAGMM, MADynGMM + + +""" Positional Encoding """ + + +class PositionalAgentEncoding(nn.Module): + def __init__( + self, + d_model, + dropout=0.1, + max_t_len=200, + max_a_len=200, + concat=False, + use_agent_enc=False, + agent_enc_learn=False, + ): + super(PositionalAgentEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + self.concat = concat + self.d_model = d_model + self.use_agent_enc = use_agent_enc + if concat: + self.fc = nn.Linear((3 if use_agent_enc else 2) * d_model, d_model) + + pe = self.build_pos_enc(max_t_len) + self.register_buffer("pe", pe) + if use_agent_enc: + if agent_enc_learn: + self.ae = nn.Parameter(torch.randn(max_a_len, 1, d_model) * 0.1) + else: + ae = self.build_pos_enc(max_a_len) + self.register_buffer("ae", ae) + + def build_pos_enc(self, max_len): + pe = torch.zeros(max_len, self.d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + return pe + + def build_agent_enc(self, max_len): + ae = torch.zeros(max_len, self.d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model) + ) + ae[:, 0::2] = torch.sin(position * div_term) + ae[:, 1::2] = torch.cos(position * div_term) + ae = ae.unsqueeze(0).transpose(0, 1) + return ae + + def get_pos_enc(self, num_t, num_a, t_offset): + pe = self.pe[t_offset : num_t + t_offset, :] + pe = pe.repeat_interleave(num_a, dim=0) + return pe + + def get_agent_enc(self, num_t, num_a, a_offset, agent_enc_shuffle): + if agent_enc_shuffle is None: + ae = self.ae[a_offset : num_a + a_offset, :] + else: + ae = self.ae[agent_enc_shuffle] + ae = ae.repeat(num_t, 1, 1) + return ae + + def forward(self, x, num_a, agent_enc_shuffle=None, t_offset=0, a_offset=0): + num_t = x.shape[0] // num_a + pos_enc = self.get_pos_enc(num_t, num_a, t_offset) + if self.use_agent_enc: + agent_enc = self.get_agent_enc(num_t, num_a, a_offset, agent_enc_shuffle) + if self.concat: + feat = [x, pos_enc.repeat(1, x.size(1), 1)] + if self.use_agent_enc: + feat.append(agent_enc.repeat(1, x.size(1), 1)) + x = torch.cat(feat, dim=-1) + x = self.fc(x) + else: + x += pos_enc + if self.use_agent_enc: + x += agent_enc + return self.dropout(x) + + +""" Context (Past) Encoder """ + + +class ContextEncoder(nn.Module): + def __init__(self, cfg, **kwargs): + super().__init__() + self.cfg = cfg + self.motion_dim = cfg["motion_dim"] + self.model_dim = cfg["tf_model_dim"] + self.ff_dim = cfg["tf_ff_dim"] + self.nhead = cfg["tf_nhead"] + self.dropout = cfg["tf_dropout"] + self.nlayer = cfg["context_encoder"]["nlayer"] + self.input_type = cfg["input_type"] + self.pooling = cfg.pooling + self.agent_enc_shuffle = cfg["agent_enc_shuffle"] + self.vel_heading = cfg["vel_heading"] + in_dim = self.motion_dim * len(self.input_type) + if "map" in self.input_type: + in_dim += cfg.map_encoder.feature_dim - self.motion_dim + self.input_fc = nn.Linear(in_dim, self.model_dim) + + encoder_layers = AgentFormerEncoderLayer( + {}, self.model_dim, self.nhead, self.ff_dim, self.dropout + ) + self.tf_encoder = AgentFormerEncoder(encoder_layers, self.nlayer) + self.pos_encoder = PositionalAgentEncoding( + self.model_dim, + self.dropout, + concat=cfg["pos_concat"], + max_a_len=cfg["max_agent_len"], + use_agent_enc=cfg["use_agent_enc"], + agent_enc_learn=cfg["agent_enc_learn"], + ) + + def forward(self, data): + pre_len, agent_num, bs = ( + data["pre_motion"].size(0), + data["pre_motion"].size(1), + data["pre_motion"].size(2), + ) + PN = pre_len * agent_num + + # get raw features + traj_in = [] + for key in self.input_type: + if key == "pos": + traj_in.append(data["pre_motion"]) # P x N x B x 2 + elif key == "vel": + vel = data["pre_vel"] # P x N x B x 2 + # if len(self.input_type) > 1: + # vel = torch.cat([vel[[0]], vel], dim=0) + if self.vel_heading: + vel = rotation_2d_torch(vel, -data["heading"])[0] + traj_in.append(vel) + elif key == "norm": + traj_in.append(data["pre_motion_norm"]) # P x N x B x 2 + elif key == "scene_norm": + traj_in.append(data["pre_motion_scene_norm"]) # P x N x B x 2 + elif key == "heading": + hv = ( + data["heading_vec"].unsqueeze(0).repeat_interleave(pre_len, dim=0) + ) # P x N x B x 2 + traj_in.append(hv) + elif key == "map": + map_enc = data["map_enc"].unsqueeze(0).repeat((pre_len, 1, 1, 1)) + traj_in.append(map_enc) + else: + raise ValueError("unknown input_type!") + + # extend the agent-pair mask to PN x PN by repeating + # src_agent_mask = data['agent_mask'].clone() # N x N + # src_mask = generate_mask(tf_in.shape[0], tf_in.shape[0], data['agent_num'], src_agent_mask).to(tf_in.device) # PN X PN + + # ******************************** create mask for NaN + + # time-stamp based masking, i.e., not masking for a whole agents + # can only mask part of the agents who have incomplete data + src_mask = ( + data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) + ) # B x PN x 1 + src_mask_square = torch.bmm(src_mask, src_mask.transpose(1, 2)) # B x PN x PN + + # due to the inverse definition in attention.py + # 0 means good, 1 means nan data + enc_mask = (1 - src_mask.transpose(0, 1)).bool() # PN x B x 1 + src_mask_square = (1 - src_mask_square).bool() # B x PN x PN + + # expand mask to head dimensions + src_mask_square = ( + src_mask_square.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * self.nhead, PN, PN) + ) # BH x PN x PN + # repeat_interleave copy for the dimenion that already has sth, e.g., B + # attach the copied dimenion in the end, i.e., BH rather than HB + # the order matters in this case since there are a lot of dimenions + # when printing the matrices, the default is to loop/list from the + # 2nd dimenion, which is H in this case, same for PN (N dim goes first) + + # ******************************** feature encoding + + # mask NaN because even simple fc cannot handle NaN in backward pass + traj_in = torch.cat(traj_in, dim=-1) # P x N x B x feat + traj_in = traj_in.view(PN, bs, traj_in.shape[-1]) # PN x B x feat + traj_in = traj_in.masked_fill_(enc_mask, float(0)) # PN x B x feat + + # input projection + tf_in = self.input_fc(traj_in) # PN x B x feat + tf_in = tf_in.masked_fill_(enc_mask, float(0.0)) + # the resulting features will contain some randome numbers in the + # invalid rows, can suppress using the above comment + # optional: but not masking will not affect the final results + + # ******************************** transformer + + # add positional embedding + agent_enc_shuffle = ( + data["agent_enc_shuffle"] if self.agent_enc_shuffle else None + ) + tf_in_pos = self.pos_encoder( + tf_in, num_a=agent_num, agent_enc_shuffle=agent_enc_shuffle + ) # PN x B x feat + + tf_in_pos = tf_in_pos.masked_fill_(enc_mask, float(0.0)) + # the resulting features will contain some randome numbers in the + # invalid rows, can suppress using the above comment + # optional: but not masking will not affect the final results + + # transformer encoder + assert not torch.isnan(tf_in_pos).any(), "error" + data["context_enc"] = self.tf_encoder( + tf_in_pos, mask=src_mask_square, num_agent=agent_num # BH x PN x PN + ) # PN x B x feat + assert not torch.isnan(data["context_enc"]).any(), "error" + + # mask NaN row (now contained random numbers due to softmax and bias in the linear layers) + # replace random numbers in the NaN rows with 0s to avoid confusion + # here, the masking is needed, otherwise will affect the prior in the pooling + data["context_enc"] = data["context_enc"].masked_fill_( + enc_mask, float(0.0) + ) # PN x B x feat + + # ******************************** compute latent distribution + + # compute per agent context for prior + # using mean will average over a few zeros for the agents with invalid data + context_rs = data["context_enc"].view( + pre_len, agent_num, bs, self.model_dim + ) # P x N x B x feat + if self.pooling == "mean": + data["agent_context"] = torch.mean(context_rs, dim=0) # N x B x feat + else: + data["agent_context"] = torch.max(context_rs, dim=0)[0] + data["agent_context"] = data["agent_context"].view( + agent_num * bs, -1 + ) # NB x feat + + +""" Future Encoder """ + + +class FutureEncoder(nn.Module): + def __init__(self, cfg, **kwargs): + super().__init__() + self.cfg = cfg + self.context_dim = context_dim = cfg["tf_model_dim"] + self.forecast_dim = forecast_dim = cfg["forecast_dim"] + self.nz = cfg["nz"] + self.z_type = cfg["z_type"] + + self.model_dim = cfg["tf_model_dim"] + self.ff_dim = cfg["tf_ff_dim"] + self.nhead = cfg["tf_nhead"] + self.dropout = cfg["tf_dropout"] + self.nlayer = cfg["future_encoder"]["nlayer"] + self.out_mlp_dim = cfg.future_decoder.out_mlp_dim + self.input_type = cfg["fut_input_type"] + self.pooling = cfg.pooling + self.agent_enc_shuffle = cfg.agent_enc_shuffle + self.vel_heading = cfg.vel_heading + # networks + in_dim = forecast_dim * len(self.input_type) + if "map" in self.input_type: + in_dim += cfg.map_encoder.feature_dim - forecast_dim + self.input_fc = nn.Linear(in_dim, self.model_dim) + + decoder_layers = AgentFormerDecoderLayer( + {}, self.model_dim, self.nhead, self.ff_dim, self.dropout + ) + self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) + self.pos_encoder = PositionalAgentEncoding( + self.model_dim, + self.dropout, + concat=cfg["pos_concat"], + max_a_len=cfg["max_agent_len"], + use_agent_enc=cfg["use_agent_enc"], + agent_enc_learn=cfg["agent_enc_learn"], + ) + num_dist_params = ( + 2 * self.nz if self.z_type == "gaussian" else self.nz + ) # either gaussian or discrete + if self.out_mlp_dim is None: + self.q_z_net = nn.Linear(self.model_dim, num_dist_params) + else: + self.out_mlp = AFMLP(self.model_dim, self.out_mlp_dim, "relu") + self.q_z_net = nn.Linear(self.out_mlp.out_dim, num_dist_params) + # initialize + initialize_weights(self.q_z_net.modules()) + + def forward(self, data, reparam=True, temp=0.1): + fut_len, agent_num, bs = ( + data["fut_motion"].size(0), + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + pre_len = data["pre_motion"].size(0) + FN = fut_len * agent_num + PN = pre_len * agent_num + + # get input feature + traj_in = [] + for key in self.input_type: + if key == "pos": + traj_in.append(data["fut_motion"]) # F x N x B x 2 + elif key == "vel": + vel = data["fut_vel"] # F x N x B x 2 + if self.vel_heading: + vel = rotation_2d_torch(vel, -data["heading"])[0] + traj_in.append(vel) + elif key == "norm": + traj_in.append(data["fut_motion_norm"]) # F x N x B x 2 + elif key == "scene_norm": + traj_in.append(data["fut_motion_scene_norm"]) # F x N x B x 2 + elif key == "heading": + hv = ( + data["heading_vec"].unsqueeze(0).repeat_interleave(fut_len, dim=0) + ) # F x N x B x 2 + traj_in.append(hv) + elif key == "map": + map_enc = ( + data["map_enc"] + .unsqueeze(0) + .repeat((data["fut_motion"].shape[0], 1, 1)) + ) + traj_in.append(map_enc) + else: + raise ValueError("unknown input_type!") + + # ******************************** create mask for NaN + + # generate masks, mem_mask for cross attention between past and future, tgt_mask for self_attention between futures + # mem_agent_mask = data['agent_mask'].clone() # N x N + # mem_mask = generate_mask(tf_in.shape[0], data['context_enc'].shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # FN x PN + # tgt_agent_mask = data['agent_mask'].clone() # N x N + # tgt_mask = generate_mask(tf_in.shape[0], tf_in.shape[0], data['agent_num'], tgt_agent_mask).to(tf_in.device) # FN x FN + + # time-stamp based masking, i.e., not masking for a whole agents + # can only mask part of the agents who have incomplete data + fut_mask = ( + data["fut_mask"].transpose(1, 2).contiguous().view(bs, FN, 1) + ) # B x FN x 1 + pre_mask = ( + data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) + ) # B x PN x 1 + mem_mask = torch.bmm(fut_mask, pre_mask.transpose(1, 2)) # B x FN x PN + tgt_mask = torch.bmm(fut_mask, fut_mask.transpose(1, 2)) # B x FN x FN + + # due to the inverse definition in attention.py + # 0 means good, 1 means nan data + enc_mask = (1 - fut_mask.transpose(0, 1)).bool() # FN x B x 1 + mem_mask = (1 - mem_mask).bool() # B x FN x PN + tgt_mask = (1 - tgt_mask).bool() # B x FN x FN + + # expand mask to head dimensions + mem_mask = ( + mem_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * self.nhead, FN, PN) + ) # BH x FN x PN + tgt_mask = ( + tgt_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * self.nhead, FN, FN) + ) # BH x FN x FN + + # ******************************** feature encoding + + # mask NaN because even simple fc cannot handle NaN in backward pass + traj_in = torch.cat(traj_in, dim=-1) # F x N x B x feat + traj_in = traj_in.view(FN, bs, traj_in.shape[-1]) # FN x B x feat + traj_in = traj_in.masked_fill_(enc_mask, float(0)) # FN x B x feat + + # input projection + tf_in = self.input_fc(traj_in) # FN x B x feat + tf_in = tf_in.masked_fill_(enc_mask, float(0.0)) # FN x B x feat + # the resulting features will contain some randome numbers in the + # invalid rows, can suppress using the above comment + # optional: but not masking will not affect the final results + + # ******************************** transformer + + # add positional embedding + agent_enc_shuffle = ( + data["agent_enc_shuffle"] if self.agent_enc_shuffle else None + ) + tf_in_pos = self.pos_encoder( + tf_in, num_a=agent_num, agent_enc_shuffle=agent_enc_shuffle + ) # FN x B x feat + tf_in_pos = tf_in_pos.masked_fill_(enc_mask, float(0.0)) + # the resulting features will contain some randome numbers in the + # invalid rows, can suppress using the above comment + # optional: but not masking will not affect the final results + + # transformer decoder (cross attention between future and context features) + assert not torch.isnan(tf_in_pos).any(), "error" + tf_out, _ = self.tf_decoder( + tf_in_pos, # FN x B x feat + data["context_enc"], # PN x B x feat + memory_mask=mem_mask, # BH x FN x PN + tgt_mask=tgt_mask, # BH x FN x FN + num_agent=agent_num, + ) # FN x B x feat + assert not torch.isnan(tf_out).any(), "error" + + # mask NaN row (now contained random numbers due to softmax and bias in the linear layers) + # replace random numbers in the NaN rows with 0s to avoid confusion + # here, the masking is needed, otherwise will affect the posterior in the pooling + tf_out = tf_out.masked_fill_(enc_mask, float(0.0)) # FN x B x feat + + # ******************************** compute latent distribution + + # compute per agent for posterior + tf_out = tf_out.view(fut_len, agent_num, bs, self.model_dim) # F x N x B x feat + if self.pooling == "mean": + h = torch.mean(tf_out, dim=0) # N x B x feat + else: + h = torch.max(tf_out, dim=0)[0] # N x B x feat + if self.out_mlp_dim is not None: + h = self.out_mlp(h) # N x B x feat + h = h.view(agent_num * bs, -1) # NB x feat + + # ******************************** sample latent code + + # sample latent code from the posterior distribution + # each agent has a separate distribution and sample independently + q_z_params = self.q_z_net(h) # NB x 64 (contain mu and var) + if self.z_type == "gaussian": + data["q_z_dist"] = Normal(params=q_z_params) + else: + data["q_z_dist"] = Categorical(logits=q_z_params, temp=temp) + data["q_z_samp"] = ( + data["q_z_dist"].rsample().reshape(agent_num, bs, -1) + ) # N x B x 32 + + +""" Future Decoder """ + + +class FutureDecoder(nn.Module): + def __init__(self, cfg, **kwargs): + super().__init__() + self.cfg = cfg + self.ar_detach = cfg["ar_detach"] + self.context_dim = context_dim = cfg["tf_model_dim"] + self.forecast_dim = forecast_dim = cfg["forecast_dim"] + self.pred_scale = cfg["pred_scale"] + self.pred_type = cfg["pred_type"] + self.sn_out_type = cfg["sn_out_type"] + self.sn_out_heading = cfg["sn_out_heading"] + self.input_type = cfg["dec_input_type"] + self.future_frames = cfg["future_num_frames"] + self.past_frames = cfg["history_num_frames"] + self.nz = cfg["nz"] + self.z_type = cfg["z_type"] + self.model_dim = cfg["tf_model_dim"] + self.ff_dim = cfg["tf_ff_dim"] + self.nhead = cfg["tf_nhead"] + self.dropout = cfg["tf_dropout"] + self.nlayer = cfg["future_decoder"]["nlayer"] + self.out_mlp_dim = cfg.future_decoder.out_mlp_dim + self.pos_offset = cfg.pos_offset + self.agent_enc_shuffle = cfg["agent_enc_shuffle"] + self.learn_prior = cfg["learn_prior"] + # networks + if self.pred_type in ["dynamic", "dynamic_var"]: + in_dim = 6 + len(self.input_type) * forecast_dim + self.nz + + if cfg.dynamic_type == "Unicycle": + self.dyn = dynamics.Unicycle(cfg.step_time) + else: + raise Exception("not supported dynamic type") + + else: + in_dim = forecast_dim + len(self.input_type) * forecast_dim + self.nz + if "map" in self.input_type: + in_dim += cfg.map_encoder.feature_dim - forecast_dim + self.input_fc = nn.Linear(in_dim, self.model_dim) + + decoder_layers = AgentFormerDecoderLayer( + {}, self.model_dim, self.nhead, self.ff_dim, self.dropout + ) + self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) + + self.pos_encoder = PositionalAgentEncoding( + self.model_dim, + self.dropout, + concat=cfg["pos_concat"], + max_a_len=cfg["max_agent_len"], + use_agent_enc=cfg["use_agent_enc"], + agent_enc_learn=cfg["agent_enc_learn"], + ) + if self.pred_type in ["scene_norm", "vel", "pos", "dynamic"]: + outdim = forecast_dim + elif self.pred_type == "dynamic_var": + outdim = forecast_dim + 2 + if self.out_mlp_dim is None: + self.out_fc = nn.Linear(self.model_dim, outdim) + else: + in_dim = self.model_dim + self.out_mlp = AFMLP(in_dim, self.out_mlp_dim, "relu") + self.out_fc = nn.Linear(self.out_mlp.out_dim, outdim) + initialize_weights(self.out_fc.modules()) + if self.learn_prior: + num_dist_params = ( + 2 * self.nz if self.z_type == "gaussian" else self.nz + ) # either gaussian or discrete + self.p_z_net = nn.Linear(self.model_dim, num_dist_params) + initialize_weights(self.p_z_net.modules()) + + def decode_traj_ar( + self, + data, + mode, + context, + input_dict, + z, + sample_num, + need_weights=False, + cond_idx=None, + ): + # z: N x BS x 32 + + fut_len, agent_num, bs = ( + data["fut_motion"].size(0), + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + pre_len = data["pre_motion"].size(0) + FN = fut_len * agent_num + PN = pre_len * agent_num + device = data["fut_motion"].device + # get input feature, only take the current timestamp as input here + if self.pred_type == "vel": + pre_vel = input_dict["pre_vel"] + fut_vel = input_dict["fut_vel"] + dec_in = torch.cat((pre_vel[[-1]], fut_vel)) # (1+F) x N x BS x 2 + elif self.pred_type == "pos": + pre_motion = input_dict["pre_motion"] + fut_motion = input_dict["fut_motion"] + dec_in = torch.cat((pre_motion[[-1]], fut_motion), 0) # (1+F) x N X BS x 2 + elif self.pred_type == "scene_norm": + pre_motion_scene_norm = input_dict["pre_motion_scene_norm"] + fut_motion_scene_norm = input_dict["fut_motion_scene_norm"] + dec_in = torch.cat( + (pre_motion_scene_norm[[-1]], fut_motion_scene_norm), 0 + ) # (1+F) x N x BS x 2 + elif self.pred_type == "dynamic": + curr_state = input_dict["curr_state"] + pre_state_vec = input_dict["pre_state_vec"] + fut_state_vec = input_dict["fut_state_vec"] + dec_in = torch.cat( + (pre_state_vec[[-1]], fut_state_vec) + ) # (1+F) x N x BS x 6 + dec_state = [curr_state] + elif self.pred_type == "dynamic_var": + curr_state = input_dict["curr_state"] + pre_state_vec = input_dict["pre_state_vec"] + fut_state_vec = input_dict["fut_state_vec"] + dec_in = torch.cat( + (pre_state_vec[[-1]], fut_state_vec) + ) # (1+F) x N x BS x 6 + dec_state = [curr_state] + + else: + dec_in = torch.zeros([1 + fut_len, agent_num, bs * sample_num, 2]).to( + device + ) # (1+F) x N x BS x 2 + + # concatenate conditional input features with latent code + # broadcast to the sample dimension + + z_tiled = z.unsqueeze(0).repeat_interleave(1 + fut_len, 0) + + dec_in = dec_in.view( + (1 + fut_len) * agent_num, bs * sample_num, dec_in.size(-1) + ) # (1+F)N x BS x feat + in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] + + # add additional features such as the map + for key in self.input_type: + if key == "heading": + heading = data["heading_vec"].repeat_interleave( + sample_num, dim=1 + ) # N x BS x 2 + heading_tiled = heading.repeat(1 + fut_len, 1, 1) + + in_arr.append(heading_tiled) + elif key == "map": + map_enc = data["map_enc"].repeat_interleave(sample_num, 1) + map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) + in_arr.append(map_enc_tiled) + else: + raise ValueError("wrong decode input type!") + dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat + device = dec_in.device + orig_dec_in_z_list = list(torch.split(dec_in_z_orig, agent_num)) + updated_dec_in_z_list = list() + dec_in_z = dec_in_z_orig.clone() + + # dec_in_z_padded = torch.cat((dec_in_z,torch.zeros(agent_num*(fut_len-1),bs,D).to(device))) + + # mem_agent_mask = data['agent_mask'].clone() + # tgt_agent_mask = data['agent_mask'].clone() + + if self.pred_type == "dynamic_var": + logvar = list() + + # predict for each timestamps auto-regressively + for fut_index in range(fut_len): + F_tmp = fut_index + 1 + FN_tmp = F_tmp * agent_num + + # ******************************** create mask for NaN + + # agent-wise masking + # mem_mask = pred_utils.generate_mask(tf_in.shape[0], context.shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # (F)N x PN + # tgt_mask = pred_utils.generate_ar_mask(tf_in_pos.shape[0], agent_num, tgt_agent_mask).to(tf_in.device) # (F)N x (F)N + + # time-stamp-based masking + # only using the last timestamp of pre_motion, i.e., the current frame of mask + # repeat it over future frames, i.e., the assumption is that the valid objects + # to predict must have data in the current frame, this is safe since we interpolated + # data in the trajdata, i.e., objects with incomplete trajectories may have NaN in the + # beginning/end of the time window, but not in the current frame + cur_mask = ( + data["pre_mask"][:, :, [-1]] + .transpose(1, 2) + .contiguous() + .view(bs, agent_num, 1) + ) # B x N x 1 + cur_mask = cur_mask.repeat_interleave(sample_num, dim=0) # BS x N x 1 + cur_mask = cur_mask.unsqueeze(1).repeat_interleave(1 + fut_len, dim=1) + + cur_mask[:, F_tmp:] = 0 + if cond_idx is not None: + cur_mask[:, :, cond_idx] = 1 + + cur_mask = cur_mask.view(bs * sample_num, (1 + fut_len) * agent_num, 1) + + pre_mask = ( + data["pre_mask"] + .transpose(1, 2) + .contiguous() + .view(bs, PN, 1) + .repeat_interleave(sample_num, dim=0) + ) # BS x PN x 1 + + mem_mask = torch.bmm(cur_mask, pre_mask.transpose(1, 2)) # BS x (1+F)N x PN + tgt_mask = torch.bmm( + cur_mask, cur_mask.transpose(1, 2) + ) # BS x (1+F)N x (1+F)N + + # due to the inverse definition in attention.py + # 0 means good, 1 means nan data now + cur_mask = (1 - cur_mask.transpose(0, 1)).bool() # (1+F)N x BS x 1 + mem_mask = (1 - mem_mask).bool() # BS x (1+F)N x PN + tgt_mask = (1 - tgt_mask).bool() # BS x (1+F)N x (1+F)N + + # expand mask to head dimensions + mem_mask = ( + mem_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * sample_num * self.nhead, (1 + fut_len) * agent_num, PN) + ) # BSH x (1+F)N x PN + tgt_mask = ( + tgt_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view( + bs * sample_num * self.nhead, + (1 + fut_len) * agent_num, + (1 + fut_len) * agent_num, + ) + ) # BSH x (1+F)N x (1+F)N + + # ******************************** feature encoding + + # mask NaN because even simple fc cannot handle NaN in backward pass + + tf_in = dec_in_z.masked_fill_(cur_mask, float(0)) # (1+F)N x BS x feat + + # input projection + tf_in = self.input_fc( + tf_in + ) # (F)N x BS x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in = tf_in.masked_fill_(cur_mask, float(0.0)) # (1+F)N x BS x feat + + # ******************************** transformer + + # add positional encoding + agent_enc_shuffle = ( + data["agent_enc_shuffle"] if self.agent_enc_shuffle else None + ) + tf_in_pos = self.pos_encoder( + tf_in, + num_a=agent_num, + agent_enc_shuffle=agent_enc_shuffle, + t_offset=self.past_frames - 1 if self.pos_offset else 0, + ) + # (F)N x BS x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in_pos = tf_in_pos.masked_fill_(cur_mask, float(0.0)) + + # transformer decoder (between predicted steps and past context) + assert not torch.isnan(tf_in_pos).any(), "error" + + tf_out, attn_weights = self.tf_decoder( + tf_in_pos, # (F)N x BS x feat + context, # PN x BS x feat + memory_mask=mem_mask, # BSH x (F)N x PN + tgt_mask=tgt_mask, # BSH x (F)N x (F)N + num_agent=agent_num, + need_weights=need_weights, + ) + + assert not torch.isnan(tf_out).any(), "error" + # tf_out: (1+F)N x BS x feat + + # ******************************** output projection + + # convert the output feature to output dimension (x, y) + # out_tmp = tf_out.view(-1, tf_out.shape[-1]) # (F)NS x feat + if self.out_mlp_dim is not None: + out_tmp = self.out_mlp(tf_out) # (F)N x BS x feat + seq_out = self.out_fc(out_tmp) # (F)N x BS x 2 + + # denormalize data and de-rotate + if self.pred_type == "scene_norm" and self.sn_out_type in {"vel", "norm"}: + norm_motion = seq_out.view( + 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] + ) # (1+F) x N x BS x 2 + + # aggregate velocity prediction to obtain location + if self.sn_out_type == "vel": + norm_motion = torch.cumsum(norm_motion, dim=0) # (1+F) x N x BS x 2 + + # default not used + if self.sn_out_heading: + angles = data["heading"].repeat_interleave(sample_num) + norm_motion = rotation_2d_torch(norm_motion, angles)[0] + + # denormalize over the scene + # we are predicting delta with respect to the current frame of data + # will introduce NaN here since the scene_norm data in the current frame has NaN + seq_out = ( + norm_motion + pre_motion_scene_norm[[-1]] + ) # (1+F) x N x BS x 2 + dec_feat_in = seq_out.view( + (1 + fut_len) * agent_num, bs * sample_num, seq_out.shape[-1] + ) # (1+F)N x BS x 2 + elif self.pred_type in ["dynamic", "dynamic_var"]: + traj_scale = data["traj_scale"] + input_seq = TensorUtils.reshape_dimensions_single( + seq_out[..., : self.forecast_dim], 0, 1, [fut_len + 1, -1] + ).permute(1, 2, 0, 3) + + # curr_state_xyhv = torch.cat((curr_state[...,:2],curr_state[...,3:],curr_state[...,2:3]),-1) + state_seq = self.dyn.forward_dynamics(curr_state, input_seq[..., 1:, :]) + # state_seq = torch.cat((state_seq[...,:2],state_seq[...,3:],state_seq[...,2:3]),-1) + state_seq = state_seq.permute(2, 0, 1, 3) + state_seq = torch.cat((curr_state.unsqueeze(0), state_seq), 0) + yaw = state_seq[..., 3:] + vel = state_seq[..., 2:3] / traj_scale + cosyaw = torch.cos(yaw) + sinyaw = torch.sin(yaw) + dec_feat_in = TensorUtils.join_dimensions( + torch.cat( + ( + state_seq[..., :2] / traj_scale, + vel * cosyaw, + vel * sinyaw, + cosyaw, + sinyaw, + ), + -1, + ), + 0, + 2, + ) + + # ******************************** prepare for the next timestamp + + # only take the last few results for the N agents predicted in the last timestamp + if self.ar_detach: + out_in = ( + dec_feat_in[F_tmp * agent_num : (1 + F_tmp) * agent_num] + .clone() + .detach() + ) # N x BS x 2(6) + else: + out_in = dec_feat_in[ + F_tmp * agent_num : (1 + F_tmp) * agent_num + ] # N x BS x 2(6) + + # create input for the next timestamp + in_arr = [out_in, z] # z: N x BS x 32 + + for key in self.input_type: + if key == "heading": + in_arr.append(heading) # z: N x BS x 2 + elif key == "map": + in_arr.append(map_enc) + else: + raise ValueError("wrong decoder input type!") + + # combine with previous information, data in normal forward order + # i.e., newly predicted information attached in the end of features + out_in_z = torch.cat(in_arr, dim=-1) # N x BS x feat + updated_dec_in_z_list.append(out_in_z) + # import pdb + # pdb.set_trace() + curr_dec_list = ( + orig_dec_in_z_list[0:1] + + updated_dec_in_z_list + + orig_dec_in_z_list[F_tmp + 1 :] + ) + dec_in_z = torch.cat(curr_dec_list, 0) + # dec_in_z[F_tmp*agent_num:(1+F_tmp)*agent_num] = out_in_z + + # seq_out: FN x BS x 2 + seq_out = seq_out.view( + 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] + ) # 1+F x N x BS x 2 + seq_out = seq_out[1:] # F x N x BS x 2 + data[f"{mode}_seq_out"] = seq_out + + if self.pred_type == "vel": + dec_motion = torch.cumsum(seq_out, dim=0) # F x N x BS x 2 + dec_motion += pre_motion[[-1]] # F x N X BS x 2 + elif self.pred_type == "pos": + dec_motion = seq_out.clone() + elif self.pred_type == "scene_norm": + dec_motion = seq_out + data["scene_orig"].repeat_interleave( + sample_num, dim=0 + ) # F x N X BS x 2 + elif self.pred_type in ["dynamic", "dynamic_var"]: + input_seq = seq_out.permute(1, 2, 0, 3) + # curr_state_xyhv = torch.cat((curr_state[...,:2],curr_state[...,3:],curr_state[...,2:3]),-1) + state_seq = self.dyn.forward_dynamics( + curr_state, input_seq[..., : self.forecast_dim] + ) + # state_seq = torch.cat((state_seq[...,:2],state_seq[...,3:],state_seq[...,2:3]),-1) + state_seq = state_seq.permute(2, 0, 1, 3) + dec_state = state_seq + dec_motion = state_seq[..., :2] / data["traj_scale"] + data["controls"] = ( + input_seq[..., : self.forecast_dim] + .transpose(0, 2) + .contiguous() + .view(bs, sample_num, agent_num, fut_len, self.forecast_dim) + ) + else: + dec_motion = seq_out + pre_motion[[-1]] # F x N X BS x 2 + + # reshape for loss computation + dec_motion = dec_motion.transpose(0, 2).contiguous() # BS x N x F x 2 + + dec_motion = dec_motion.view( + bs, sample_num, agent_num, fut_len, dec_motion.size(-1) + ) # B x S x N x F x 2 + if self.pred_type in ["dynamic", "dynamic_var"]: + dec_state = ( + dec_state.transpose(0, 2) + .contiguous() + .view(bs, sample_num, agent_num, fut_len, dec_state.size(-1)) + ) + data[f"{mode}_dec_state"] = dec_state + if self.pred_type == "dynamic_var": + logvar = seq_out[..., self.forecast_dim : 2 * self.forecast_dim] + var = torch.exp(logvar) * data["traj_scale"] ** 2 + var = ( + var.permute(2, 1, 0, 3) + .contiguous() + .view(bs, sample_num, agent_num, fut_len, var.size(-1)) + ) + data[f"{mode}_var"] = var + + data[f"{mode}_dec_motion"] = dec_motion + if need_weights: + data["attn_weights"] = attn_weights + + def decode_traj_batch( + self, + data, + mode, + context, + input_dict, + z, + sample_num, + ): + raise NotImplementedError + + def forward( + self, + data, + mode, + sample_num=1, + autoregress=True, + z=None, + need_weights=False, + cond_idx=None, + temp=0.1, + predict=False, + ): + agent_num, bs = ( + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + + # conditional input to the decoding process + context = data["context_enc"].repeat_interleave( + sample_num, dim=1 + ) # PN x BS x feat + + pre_motion = data["pre_motion"].repeat_interleave( + sample_num, dim=2 + ) # P x N X BS x 2 + fut_motion = data["fut_motion"].repeat_interleave( + sample_num, dim=2 + ) # F x N X BS x 2 + pre_motion_scene_norm = data["pre_motion_scene_norm"].repeat_interleave( + sample_num, dim=2 + ) # P x N x BS x 2 + fut_motion_scene_norm = data["fut_motion_scene_norm"].repeat_interleave( + sample_num, dim=2 + ) # F x N x BS x 2 + input_dict = dict( + pre_motion=pre_motion, + fut_motion=fut_motion, + pre_motion_scene_norm=pre_motion_scene_norm, + fut_motion_scene_norm=fut_motion_scene_norm, + ) + if self.pred_type == "vel": + input_dict["pre_vel"] = data["pre_vel"].repeat_interleave( + sample_num, dim=2 + ) # P x N x BS x 2 + input_dict["fut_vel"] = data["fut_vel"].repeat_interleave( + sample_num, dim=2 + ) # F x N x BS x 2 + elif self.pred_type in ["dynamic", "dynamic_var"]: + traj_scale = data["traj_scale"] + pre_state = torch.cat( + ( + data["pre_motion"] * traj_scale, + torch.norm(data["pre_vel"], dim=-1, keepdim=True) * traj_scale, + data["pre_heading_raw"].transpose(0, 2).unsqueeze(-1), + ), + -1, + ) # P x N x B x 4 (unscaled) + + pre_state_vec = torch.cat( + (data["pre_motion"], data["pre_vel"], data["pre_heading_vec"]), -1 + ) # P x N x B x 6 (scaled) + fut_state_vec = torch.cat( + (data["fut_motion"], data["fut_vel"], data["fut_heading_vec"]), -1 + ) # F x N x B x 6 (scaled) + input_dict["curr_state"] = pre_state[-1].repeat_interleave( + sample_num, dim=1 + ) + input_dict["pre_state_vec"] = pre_state_vec.repeat_interleave( + sample_num, dim=2 + ) + input_dict["fut_state_vec"] = fut_state_vec.repeat_interleave( + sample_num, dim=2 + ) + + # p(z), compute prior distribution + if mode == "infer": + prior_key = "p_z_dist_infer" + else: + prior_key = "q_z_dist" if "q_z_dist" in data else "p_z_dist" + + if self.learn_prior: + p_z_params0 = self.p_z_net(data["agent_context"]) + + h = data["agent_context"].repeat_interleave(sample_num, dim=0) # NBS x feat + p_z_params = self.p_z_net(h) # NBS x 64 + if self.z_type == "gaussian": + data["p_z_dist_infer"] = Normal(params=p_z_params) + data["p_z_dist"] = Normal(params=p_z_params0) + else: + data["p_z_dist_infer"] = Categorical(logits=p_z_params, temp=temp) + data["p_z_dist"] = Categorical(logits=p_z_params0, temp=temp) + else: + if self.z_type == "gaussian": + data[prior_key] = Normal( + mu=torch.zeros(pre_motion.shape[1], self.nz).to(pre_motion.device), + logvar=torch.zeros(pre_motion.shape[1], self.nz).to( + pre_motion.device + ), + ) + else: + data[prior_key] = Categorical( + logits=torch.zeros(pre_motion.shape[1], self.nz).to( + pre_motion.device + ) + ) + + # sample latent code from the distribution + if z is None: + # use latent code z from posterior for training + if mode == "train": + z = data["q_z_samp"] # N x B x 32 + + # use latent code z from posterior for evaluating the reconstruction loss + elif mode == "recon": + z = data["q_z_dist"].mode() # NB x 32 + z = z.view(agent_num, bs, z.size(-1)) # N x B x 32 + + # use latent code z from the prior for inference + elif mode == "infer": + # dist = data["p_z_dist_infer"] if "p_z_dist_infer" in data else data["q_z_dist_infer"] + # z = dist.sample() # NBS x 32 + # import pdb + # pdb.set_trace() + + dist = ( + data["q_z_dist"] + if data["q_z_dist"] is not None + else data["p_z_dist"] + ) + if self.z_type == "gaussian": + if predict: + z = dist.pseudo_sample(sample_num) + else: + z = data["p_z_dist_infer"].sample() + D = z.shape[-1] + samples = z.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3) + mu = dist.mu.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3)[ + :, :, 0 + ] + sigma = dist.sigma.reshape(agent_num, bs, -1, D).permute( + 1, 0, 2, 3 + )[:, :, 0] + data["prob"] = self.pseudo_sample_prob( + samples, mu, sigma, data["agent_avail"] + ) + elif self.z_type == "discrete": + if predict: + z = dist.pseudo_sample(sample_num).contiguous() + else: + z = dist.rsample(sample_num).contiguous() + D = z.shape[-1] + idx = z.argmax(dim=-1) + prob_sample = torch.gather(dist.probs, -1, idx) + prob_sample = prob_sample.reshape(agent_num, bs, -1).mean(0) + prob_sample = prob_sample / prob_sample.sum(-1, keepdim=True) + data["prob"] = prob_sample + + z = z.view(agent_num, bs * sample_num, z.size(-1)) # N x BS x 32 + else: + raise ValueError("Unknown Mode!") + + # trajectory decoding + if autoregress: + self.decode_traj_ar( + data, + mode, + context, + input_dict, + z, + sample_num, + need_weights=need_weights, + cond_idx=cond_idx, + ) + # self.decode_traj_ar_orig( + # data, + # mode, + # context, + # pre_motion, + # pre_vel, + # pre_motion_scene_norm, + # z, + # sample_num, + # need_weights=need_weights, + # ) + else: + self.decode_traj_batch( + data, + mode, + context, + input_dict, + z, + sample_num, + ) + + def pseudo_sample_prob(self, sample, mu, sigma, mask): + """ + A simple K-means estimation to estimate the probability of samples + """ + bs, Na, Ns, D = sample.shape + device = sample.device + Np = Ns * 50 + particle = torch.randn([bs, Na, Np, D]).to(device) * sigma.unsqueeze( + -2 + ) + mu.unsqueeze(-2) + dis = torch.linalg.norm(sample.unsqueeze(-2) - particle.unsqueeze(-3), dim=-1) + dis = (dis * mask[..., None, None]).sum(1) + idx = torch.argmin(dis, -2) + flag = idx.unsqueeze(1) == torch.arange(Ns).view(1, Ns, 1).repeat_interleave( + bs, 0 + ).to(device) + prob = flag.sum(-1) / Np + return prob + + +class FutureARDecoder(nn.Module): + def __init__(self, cfg, **kwargs): + super().__init__() + self.cfg = cfg + self.ar_detach = cfg["ar_detach"] + self.context_dim = context_dim = cfg["tf_model_dim"] + self.forecast_dim = forecast_dim = cfg["forecast_dim"] + self.pred_scale = cfg["pred_scale"] + self.pred_type = cfg["pred_type"] + self.sn_out_type = cfg["sn_out_type"] + self.sn_out_heading = cfg["sn_out_heading"] + self.input_type = cfg["dec_input_type"] + self.future_frames = cfg["future_num_frames"] + self.past_frames = cfg["history_num_frames"] + self.z_type = cfg["z_type"] + self.nz = cfg["nz"] if self.z_type != "None" else 0 + self.model_dim = cfg["tf_model_dim"] + self.ff_dim = cfg["tf_ff_dim"] + self.nhead = cfg["tf_nhead"] + self.dropout = cfg["tf_dropout"] + self.nlayer = cfg["future_decoder"]["nlayer"] + self.out_mlp_dim = cfg.future_decoder.out_mlp_dim + self.pos_offset = cfg.pos_offset + self.agent_enc_shuffle = cfg["agent_enc_shuffle"] + self.learn_prior = cfg["learn_prior"] + # networks + assert self.pred_type == "dynamic_AR" + in_dim = 6 + len(self.input_type) * forecast_dim + self.nz + + if cfg.dynamic_type == "Unicycle": + self.dyn = dynamics.Unicycle(cfg.step_time) + else: + raise Exception("not supported dynamic type") + + if "map" in self.input_type: + in_dim += cfg.map_encoder.feature_dim - forecast_dim + self.input_fc = nn.Linear(in_dim, self.model_dim) + + decoder_layers = AgentFormerDecoderLayer( + {}, self.model_dim, self.nhead, self.ff_dim, self.dropout + ) + self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) + + self.pos_encoder = PositionalAgentEncoding( + self.model_dim, + self.dropout, + concat=cfg["pos_concat"], + max_a_len=cfg["max_agent_len"], + use_agent_enc=cfg["use_agent_enc"], + agent_enc_learn=cfg["agent_enc_learn"], + ) + if cfg.dist_type == "gaussian": + outdim = self.dyn.udim * 2 + cfg.scene_var_dim * self.dyn.udim + if cfg.output_varx and cfg.dist_obj == "state": + outdim += self.dyn.xdim + elif cfg.dist_type == "GMM": + self.GMM_M = cfg.GMM_M + outdim = (forecast_dim * 2 + cfg.scene_var_dim * forecast_dim) * self.GMM_M + if cfg.output_varx and cfg.dist_obj == "state": + outdim += self.dyn.xdim * self.GMM_M + self.GMM_pi_net = nn.Linear(self.model_dim, self.GMM_M) + + if self.out_mlp_dim is None: + self.out_fc = nn.Linear(self.model_dim, outdim) + else: + in_dim = self.model_dim + self.out_mlp = AFMLP(in_dim, self.out_mlp_dim, "relu") + self.out_fc = nn.Linear(self.out_mlp.out_dim, outdim) + initialize_weights(self.out_fc.modules()) + if self.learn_prior and self.z_type != "None": + num_dist_params = ( + 2 * self.nz if self.z_type == "gaussian" else self.nz + ) # either gaussian or discrete + self.p_z_net = nn.Linear(self.model_dim, num_dist_params) + initialize_weights(self.p_z_net.modules()) + + def decode_traj_ar( + self, + data, + mode, + context, + input_dict, + z, + sample_num, + gt_step, + need_weights=False, + ): + # z: N x BS x 32 + fut_len, agent_num, bs = ( + data["fut_motion"].size(0), + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + # assert mode=="infer" + gt_step = gt_step if gt_step < fut_len else fut_len - 1 + pre_len = data["pre_motion"].size(0) + FN = fut_len * agent_num + PN = pre_len * agent_num + # get input feature, only take the current timestamp as input here + + curr_state = input_dict["curr_state"] + pre_state_vec = input_dict["pre_state_vec"] + fut_state_vec = input_dict["fut_state_vec"] + dec_in = torch.cat((pre_state_vec[[-1]], fut_state_vec)) # (1+F) x N x BS x 6 + dec_state = [curr_state] + + # concatenate conditional input features with latent code + # broadcast to the sample dimension + + dec_in = dec_in.view( + (1 + fut_len) * agent_num, bs * sample_num, dec_in.size(-1) + ) # (1+F)N x BS x feat + if z is not None: + z_tiled = z.unsqueeze(0).repeat_interleave(1 + fut_len, 0) + in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] + else: + in_arr = [dec_in] + + # add additional features such as the map + for key in self.input_type: + if key == "heading": + heading = data["heading_vec"].repeat_interleave( + sample_num, dim=1 + ) # N x BS x 2 + heading_tiled = heading.repeat(1 + fut_len, 1, 1) + + in_arr.append(heading_tiled) + elif key == "map": + map_enc = data["map_enc"].repeat_interleave(sample_num, 1) + map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) + in_arr.append(map_enc_tiled) + else: + raise ValueError("wrong decode input type!") + dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat + orig_dec_in_z_list = list(torch.split(dec_in_z_orig, agent_num)) + updated_dec_in_z_list = list() + dec_in_z = dec_in_z_orig.clone() + + # mem_agent_mask = data['agent_mask'].clone() + # tgt_agent_mask = data['agent_mask'].clone() + + # predict for each timestamps auto-regressively + + input_pred = [ + torch.zeros( + [agent_num, bs * sample_num, self.dyn.udim], device=curr_state.device + ) + for _ in range(fut_len) + ] + state_pred = [torch.zeros_like(curr_state) for _ in range(fut_len + 1)] + state_pred[0] = curr_state + for i in range(1, 1 + gt_step): + state_pred[i] = input_dict["fut_state"][i - 1] + + for fut_index in range(gt_step, fut_len): + F_tmp = fut_index + 1 + + # ******************************** create mask for NaN + + # agent-wise masking + # mem_mask = pred_utils.generate_mask(tf_in.shape[0], context.shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # (F)N x PN + # tgt_mask = pred_utils.generate_ar_mask(tf_in_pos.shape[0], agent_num, tgt_agent_mask).to(tf_in.device) # (F)N x (F)N + + # time-stamp-based masking + # only using the last timestamp of pre_motion, i.e., the current frame of mask + # repeat it over future frames, i.e., the assumption is that the valid objects + # to predict must have data in the current frame, this is safe since we interpolated + # data in the trajdata, i.e., objects with incomplete trajectories may have NaN in the + # beginning/end of the time window, but not in the current frame + cur_mask = ( + data["pre_mask"][:, :, [-1]] + .transpose(1, 2) + .contiguous() + .view(bs, agent_num, 1) + ) # B x N x 1 + cur_mask = cur_mask.repeat_interleave(sample_num, dim=0) # BS x N x 1 + cur_mask = cur_mask.unsqueeze(1).repeat_interleave(1 + fut_len, dim=1) + + cur_mask[:, F_tmp:] = 0 + + cur_mask = cur_mask.view(bs * sample_num, (1 + fut_len) * agent_num, 1) + + pre_mask = ( + data["pre_mask"] + .transpose(1, 2) + .contiguous() + .view(bs, PN, 1) + .repeat_interleave(sample_num, dim=0) + ) # BS x PN x 1 + + mem_mask = torch.bmm(cur_mask, pre_mask.transpose(1, 2)) # BS x (1+F)N x PN + tgt_mask = torch.bmm( + cur_mask, cur_mask.transpose(1, 2) + ) # BS x (1+F)N x (1+F)N + + # due to the inverse definition in attention.py + # 0 means good, 1 means nan data now + cur_mask = (1 - cur_mask.transpose(0, 1)).bool() # (1+F)N x BS x 1 + mem_mask = (1 - mem_mask).bool() # BS x (1+F)N x PN + tgt_mask = (1 - tgt_mask).bool() # BS x (1+F)N x (1+F)N + + # expand mask to head dimensions + mem_mask = ( + mem_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * sample_num * self.nhead, (1 + fut_len) * agent_num, PN) + ) # BSH x (1+F)N x PN + tgt_mask = ( + tgt_mask.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view( + bs * sample_num * self.nhead, + (1 + fut_len) * agent_num, + (1 + fut_len) * agent_num, + ) + ) # BSH x (1+F)N x (1+F)N + + # ******************************** feature encoding + + # mask NaN because even simple fc cannot handle NaN in backward pass + + tf_in = dec_in_z.masked_fill_(cur_mask, float(0)) # (1+F)N x BS x feat + + # input projection. + tf_in = self.input_fc( + tf_in + ) # (F)N x BS x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in = tf_in.masked_fill_(cur_mask, float(0.0)) # (1+F)N x BS x feat + + # ******************************** transformer + + # add positional encoding + agent_enc_shuffle = ( + data["agent_enc_shuffle"] if self.agent_enc_shuffle else None + ) + tf_in_pos = self.pos_encoder( + tf_in, + num_a=agent_num, + agent_enc_shuffle=agent_enc_shuffle, + t_offset=self.past_frames - 1 if self.pos_offset else 0, + ) + # (F)N x BS x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in_pos = tf_in_pos.masked_fill_(cur_mask, float(0.0)) + + # transformer decoder (between predicted steps and past context) + assert not torch.isnan(tf_in_pos).any(), "error" + + tf_out, attn_weights = self.tf_decoder( + tf_in_pos, # (F)N x BS x feat + context, # PN x BS x feat + memory_mask=mem_mask, # BSH x (F)N x PN + tgt_mask=tgt_mask, # BSH x (F)N x (F)N + num_agent=agent_num, + need_weights=need_weights, + ) + + assert not torch.isnan(tf_out).any(), "error" + # tf_out: (1+F)N x BS x feat + + # ******************************** output projection + + # convert the output feature to output dimension (x, y) + # out_tmp = tf_out.view(-1, tf_out.shape[-1]) # (F)NS x feat + x_t = state_pred[fut_index] + if self.out_mlp_dim is not None: + out_tmp = self.out_mlp(tf_out) # (F)N x BS x feat + seq_out = self.out_fc(out_tmp) # (F)N x BS x 2 + seq_out_T = TensorUtils.reshape_dimensions_single( + seq_out, 0, 1, (fut_len + 1, -1) + ) + xdim, udim = self.dyn.xdim, self.dyn.udim + if self.cfg.dist_obj == "state": + if self.cfg.dist_type == "gaussian": + seq_out_t = seq_out_T[fut_index].transpose(0, 1) + mu_u = seq_out_t[..., :udim].reshape( + bs * sample_num, agent_num, udim + ) + logvar_u = seq_out_t[..., udim : 2 * udim].reshape( + bs * sample_num, agent_num, udim + ) + if self.cfg.output_varx: + logvar_x = seq_out_t[..., 2 * udim : 2 * udim + xdim].reshape( + bs * sample_num, agent_num, xdim + ) + var_x = torch.exp(logvar_x) + K = seq_out_t[..., 2 * udim + xdim :].reshape( + bs * sample_num, agent_num, udim, self.cfg.scene_var_dim + ) + else: + K = seq_out_t[..., 2 * udim :].reshape( + bs * sample_num, agent_num, udim, self.cfg.scene_var_dim + ) + var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ + None, None + ] * torch.ones( + [bs * sample_num, agent_num, xdim], device=mu_u.device + ) + dist = MADynGaussian(mu_u, torch.exp(logvar_u), var_x, K, self.dyn) + elif self.cfg.dist_type == "GMM": + seq_out_t = ( + seq_out_T[fut_index] + .transpose(0, 1) + .reshape(bs * sample_num, agent_num, self.GMM_M, -1) + .transpose(1, 2) + ) + mu_u = seq_out_t[..., : self.forecast_dim] + logvar_u = seq_out_t[..., self.forecast_dim : 2 * self.forecast_dim] + if self.cfg.output_varx: + logvar_x = seq_out_t[..., 2 * udim : 2 * udim + xdim].reshape( + bs * sample_num, self.GMM_M, agent_num, xdim + ) + var_x = torch.exp(logvar_x) + K = seq_out_t[..., 2 * udim + xdim :].reshape( + bs * sample_num, + self.GMM_M, + agent_num, + udim, + self.cfg.scene_var_dim, + ) + else: + K = seq_out_t[..., 2 * udim :].reshape( + *seq_out_t.shape[:-1], udim, self.cfg.scene_var_dim + ) + var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ + None, None, None + ] * torch.ones( + [bs * sample_num, self.GMM_M, agent_num, self.dyn.xdim], + device=mu_u.device, + ) + tf_feature_pooled = tf_out.reshape( + [fut_len + 1, agent_num, bs * sample_num, -1] + )[fut_index].max(0)[0] + logpi = self.GMM_pi_net(tf_feature_pooled) + pi = torch.softmax(logpi, dim=-1) + dist = MADynGMM(mu_u, torch.exp(logvar_u), var_x, K, pi, self.dyn) + + # mu_u = seq_out_T[...,:self.forecast_dim] + # logvar_u = seq_out_T[...,self.forecast_dim:2*self.forecast_dim] + # scene_var_M = seq_out_T[...,2*self.forecast_dim:].reshape(*seq_out_T.shape[:-1],self.forecast_dim,self.cfg.scene_var_dim) + # var_u = torch.exp(logvar_u) + # scene_noise = torch.randn(*scene_var_M.shape[1:-2],self.cfg.scene_var_dim).to(mu_u.device) + + # u_t_sample = mu_u[fut_index]+torch.randn_like(mu_u[fut_index])*torch.sqrt(var_u[fut_index]) + (scene_var_M[fut_index]@scene_noise.unsqueeze(-1)).squeeze(-1) + + # u_t_sample = u_t_sample.squeeze(-2) + # input_pred[fut_index] = u_t_sample + + # xp = self.dyn.step(x_t,u_t_sample) + xp = ( + dist.rsample( + x_t.transpose(0, 1).reshape(bs * sample_num, agent_num, -1), 1 + ) + .squeeze(1) + .transpose(0, 1) + ) + + elif self.cfg.dist_obj == "input": + if self.cfg.dist_type == "gaussian": + seq_out_t = seq_out_T[fut_index].transpose(0, 1) + mu_u = seq_out_t[..., :udim].reshape( + bs * sample_num, agent_num, udim + ) + logvar_u = seq_out_t[..., udim : 2 * udim].reshape( + bs * sample_num, agent_num, udim + ) + K = seq_out_t[..., 2 * udim :].reshape( + bs * sample_num, agent_num, udim, self.cfg.scene_var_dim + ) + dist = MAGaussian(mu_u, torch.exp(logvar_u), K) + + elif self.cfg.dist_type == "GMM": + seq_out_t = ( + seq_out_T[fut_index] + .transpose(0, 1) + .reshape(bs * sample_num, agent_num, self.GMM_M, -1) + .transpose(1, 2) + ) + mu_u = seq_out_t[..., : self.forecast_dim] + logvar_u = seq_out_t[..., self.forecast_dim : 2 * self.forecast_dim] + K = seq_out_t[..., 2 * self.forecast_dim :].reshape( + *seq_out_t.shape[:-1], self.forecast_dim, self.cfg.scene_var_dim + ) + tf_feature_pooled = tf_out.reshape( + [fut_len + 1, agent_num, bs * sample_num, -1] + )[fut_index].max(0)[0] + logpi = self.GMM_pi_net(tf_feature_pooled) + pi = torch.softmax(logpi, dim=-1) + dist = MAGMM(mu_u, torch.exp(logvar_u), K, pi) + + up = dist.rsample(1).squeeze(1).transpose(0, 1) + xp = self.dyn.step(x_t, up) + # denormalize data and de-rotate + state_pred[fut_index + 1] = xp + traj_scale = data["traj_scale"] + state_seq = torch.stack(state_pred, 0).clone() + yaw = state_seq[..., 3:] + vel = state_seq[..., 2:3] / traj_scale + cosyaw = torch.cos(yaw) + sinyaw = torch.sin(yaw) + dec_feat_in = TensorUtils.join_dimensions( + torch.cat( + ( + state_seq[..., :2] / traj_scale, + vel * cosyaw, + vel * sinyaw, + cosyaw, + sinyaw, + ), + -1, + ), + 0, + 2, + ) + + # ******************************** prepare for the next timestamp + + # only take the last few results for the N agents predicted in the last timestamp + if self.ar_detach: + out_in = dec_feat_in[ + F_tmp * agent_num : (1 + F_tmp) * agent_num + ].detach() # N x BS x 2(6) + else: + out_in = dec_feat_in[ + F_tmp * agent_num : (1 + F_tmp) * agent_num + ] # N x BS x 2(6) + + # create input for the next timestamp + if z is not None: + in_arr = [out_in, z] # z: N x BS x 32 + else: + in_arr = [out_in] + + for key in self.input_type: + if key == "heading": + in_arr.append(heading) # z: N x BS x 2 + elif key == "map": + in_arr.append(map_enc) + else: + raise ValueError("wrong decoder input type!") + + # combine with previous information, data in normal forward order + # i.e., newly predicted information attached in the end of features + out_in_z = torch.cat(in_arr, dim=-1) # N x BS x feat + updated_dec_in_z_list.append(out_in_z) + + curr_dec_list = ( + orig_dec_in_z_list[: gt_step + 1] + + updated_dec_in_z_list + + orig_dec_in_z_list[F_tmp + 1 :] + ) + dec_in_z = torch.cat(curr_dec_list, 0) + del updated_dec_in_z_list, orig_dec_in_z_list + # seq_out: FN x BS x 2 + seq_out = seq_out.view( + 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] + ) # 1+F x N x BS x 2 + seq_out = seq_out[1:] # F x N x BS x 2 + data[f"{mode}_seq_out"] = seq_out + + state_pred = torch.stack(state_pred, 0) + input_pred = torch.stack(input_pred, 0) + + dec_state = state_pred[1:] + dec_motion = state_pred[1:, ..., :2] / data["traj_scale"] + data["controls"] = input_pred[..., : self.dyn.udim].permute(1, 2, 0, 3) + + # reshape for loss computation + dec_motion = dec_motion.transpose(0, 2).contiguous() # BS x N x F x 2 + dec_motion = dec_motion.view( + bs, sample_num, agent_num, fut_len, dec_motion.size(-1) + ) # B x S x N x F x 2 + dec_state = ( + dec_state.transpose(0, 2) + .contiguous() + .view(bs, sample_num, agent_num, fut_len, dec_state.size(-1)) + ) + data[f"{mode}_dec_state"] = dec_state + + data[f"{mode}_dec_motion"] = dec_motion + if need_weights: + data["attn_weights"] = attn_weights + + def calc_traj_likelihood( + self, + data, + context, + input_dict, + z, + ): + # z: N x BS x 32 + fut_len, agent_num, bs = ( + data["fut_motion"].size(0), + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + pre_len = data["pre_motion"].size(0) + FN = fut_len * agent_num + PN = pre_len * agent_num + # get input feature, only take the current timestamp as input here + + curr_state = input_dict["curr_state"][:, :bs] + pre_state_vec = input_dict["pre_state_vec"][:, :, :bs] + fut_state_vec = input_dict["fut_state_vec"][:, :, :bs] + dec_in = torch.cat((pre_state_vec[[-1]], fut_state_vec)) # (1+F) x N x B x 6 + dec_state = [curr_state] + + # concatenate conditional input features with latent code + # broadcast to the sample dimension + + dec_in = dec_in.view( + (1 + fut_len) * agent_num, bs, dec_in.size(-1) + ) # (1+F)N x B x feat + if z is not None: + z_tiled = z[:, :bs].unsqueeze(0).repeat_interleave(1 + fut_len, 0) + in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] + else: + in_arr = [dec_in] + + # add additional features such as the map + for key in self.input_type: + if key == "heading": + heading = data["heading_vec"] # N x B x 2 + heading_tiled = heading.repeat(1 + fut_len, 1, 1) + + in_arr.append(heading_tiled) + elif key == "map": + map_enc = data["map_enc"] + map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) + in_arr.append(map_enc_tiled) + else: + raise ValueError("wrong decode input type!") + dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat + dec_in_z = dec_in_z_orig.clone() + + # predict for each timestamps auto-regressively + + state_seq = torch.cat( + [curr_state.unsqueeze(0), input_dict["fut_state"][:, :, :bs]], 0 + ) + + cur_mask = ( + data["pre_mask"][:, :, [-1]] + .transpose(1, 2) + .contiguous() + .view(bs, agent_num, 1) + ).detach() # B x N x 1 + cur_mask = cur_mask.unsqueeze(1).repeat_interleave( + 1 + fut_len, dim=1 + ) # B x (1+F) x Na x 1 + cur_mask_tiled = cur_mask.unsqueeze(1).repeat_interleave( + fut_len, dim=1 + ) # B x F x (1+F) x Na x 1 + fut_mask = torch.tril(torch.ones([fut_len, fut_len + 1]), 0).to(cur_mask.device) + cur_mask_tiled *= fut_mask[None, :, :, None, None] + cur_mask_tiled = cur_mask_tiled.view(-1, (1 + fut_len) * agent_num, 1) + pre_mask = ( + data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) + ).detach() # BS x PN x 1 + pre_mask_tiled = pre_mask.repeat_interleave(fut_len, dim=0) # BSF x PN x 1 + mem_mask_tiled = torch.bmm( + cur_mask_tiled, pre_mask_tiled.transpose(1, 2) + ) # BSF x (1+F)N x PN + tgt_mask_tiled = torch.bmm( + cur_mask_tiled, cur_mask_tiled.transpose(1, 2) + ) # BSF x (1+F)N x (1+F)N + # due to the inverse definition in attention.py + # 0 means good, 1 means nan data now + cur_mask_tiled = (1 - cur_mask_tiled.transpose(0, 1)).bool() # (1+F)N x BSF x 1 + mem_mask_tiled = (1 - mem_mask_tiled).bool() # BSF x (1+F)N x PN + tgt_mask_tiled = (1 - tgt_mask_tiled).bool() # BSF x (1+F)N x (1+F)N + + # expand mask to head dimensions + mem_mask_tiled = ( + mem_mask_tiled.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view(bs * fut_len * self.nhead, (1 + fut_len) * agent_num, PN) + ) # BSFH x (1+F)N x PN + tgt_mask_tiled = ( + tgt_mask_tiled.unsqueeze(1) + .repeat_interleave(self.nhead, dim=1) + .view( + bs * fut_len * self.nhead, + (1 + fut_len) * agent_num, + (1 + fut_len) * agent_num, + ) + ) # BSFH x (1+F)N x (1+F)N + tf_in_tiled = dec_in_z.repeat_interleave(fut_len, 1) * torch.logical_not( + cur_mask_tiled + ) + # input projection + tf_in_tiled = self.input_fc( + tf_in_tiled + ) # (F)N x BSF x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in_tiled = tf_in_tiled.masked_fill_( + cur_mask_tiled, float(0.0) + ) # (1+F)N x BS x feat + + # ******************************** transformer + + # add positional encoding + agent_enc_shuffle = ( + data["agent_enc_shuffle"] if self.agent_enc_shuffle else None + ) + tf_in_pos_tiled = self.pos_encoder( + tf_in_tiled, + num_a=agent_num, + agent_enc_shuffle=agent_enc_shuffle, + t_offset=self.past_frames - 1 if self.pos_offset else 0, + ) + # (F)N x BSF x feat, F is increamentally increased + + # optional: not masking will not affect the final results + # just to suppress some random numbers generated by linear layer's bias + # for cleaner printing, but these random numbers are not used later due to masking + tf_in_pos_tiled = tf_in_pos_tiled.masked_fill_(cur_mask_tiled, float(0.0)) + + # transformer decoder (between predicted steps and past context) + assert not torch.isnan(tf_in_pos_tiled).any(), "error" + context_tiled = context[:, :bs].repeat_interleave(fut_len, 1) + tf_out_tiled, _ = self.tf_decoder( + tf_in_pos_tiled, # (F)N x BSF x feat + context_tiled, # PN x BSF x feat + memory_mask=mem_mask_tiled, # BH x (F)N x PN + tgt_mask=tgt_mask_tiled, # BH x (F)N x (F)N + num_agent=agent_num, + need_weights=False, + ) + + assert not torch.isnan(tf_out_tiled).any(), "error" + # tf_out: (1+F)N x BSF x feat + + # ******************************** output projection + + # convert the output feature to output dimension (x, y) + if self.out_mlp_dim is not None: + out_tmp_tiled = self.out_mlp(tf_out_tiled) # (F)N x BSF x feat + seq_out_tiled = self.out_fc(out_tmp_tiled) + # select the diagonal of the output + seq_out_tiled = seq_out_tiled.reshape([fut_len + 1, agent_num, bs, fut_len, -1]) + seq_out_diag = torch.diagonal(seq_out_tiled, dim1=0, dim2=3).permute( + 1, 3, 0, 2 + ) # bs x fut_len x agent_num x dim + seq_out_diag = seq_out_diag.reshape(bs * fut_len, agent_num, -1) + mask = ( + data["fut_mask"] + .transpose(1, 2)[..., None] + .reshape(bs * fut_len, agent_num, -1) + ) + seq_out_diag *= mask + xdim, udim = self.dyn.xdim, self.dyn.udim + if self.cfg.dist_obj == "state": + if self.cfg.dist_type == "gaussian": + mu_u = seq_out_diag[..., :udim] + logvar_u = seq_out_diag[..., udim : 2 * udim] + if self.cfg.output_varx: + logvar_x = seq_out_diag[..., 2 * udim : 2 * udim + xdim] + var_x = torch.exp(logvar_x) + scene_noise_M = seq_out_diag[..., 2 * udim + xdim :].reshape( + *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim + ) + else: + var_x = ( + torch.tensor(self.cfg.min_var_x, device=mu_u.device)[None, None] + * torch.ones( + [bs * fut_len, agent_num, self.dyn.xdim], device=mu_u.device + ) + * mask + ) + scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( + *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim + ) + K = scene_noise_M.reshape( + bs * fut_len, agent_num, -1, self.cfg.scene_var_dim + ) + dist = MADynGaussian(mu_u, torch.exp(logvar_u), var_x, K, self.dyn) + elif self.cfg.dist_type == "GMM": + seq_out_diag = seq_out_diag.reshape( + bs * fut_len, agent_num, self.GMM_M, -1 + ).transpose(1, 2) + mu_u = seq_out_diag[..., : self.forecast_dim] + logvar_u = seq_out_diag[..., self.forecast_dim : 2 * self.forecast_dim] + if self.cfg.output_varx: + logvar_x = seq_out_diag[..., 2 * udim : 2 * udim + xdim] + var_x = torch.exp(logvar_x) + scene_noise_M = seq_out_diag[..., 2 * udim + xdim :].reshape( + *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim + ) + else: + scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( + *seq_out_diag.shape[:-1], + self.forecast_dim, + self.cfg.scene_var_dim, + ) + var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ + None, None, None + ] * torch.ones( + [bs * fut_len, self.GMM_M, agent_num, self.dyn.xdim], + device=mu_u.device, + ) + K = scene_noise_M + + tf_feature = tf_out_tiled.reshape( + [fut_len + 1, agent_num, bs, fut_len, -1] + ) + tf_feature_diag = ( + torch.diagonal(tf_feature, dim1=0, dim2=3) + .permute(1, 3, 0, 2) + .reshape(bs * fut_len, agent_num, -1) + ) + tf_feature_pooled = tf_feature_diag.max(1)[0] + logpi = self.GMM_pi_net(tf_feature_pooled) + pi = torch.softmax(logpi, -1) + dist = MADynGMM(mu_u, torch.exp(logvar_u), var_x, K, pi, self.dyn) + xp = state_seq[1:].permute(2, 0, 1, 3).reshape(bs * fut_len, agent_num, -1) + x0 = ( + state_seq[:fut_len] + .permute(2, 0, 1, 3) + .reshape(bs * fut_len, agent_num, -1) + ) + + log_prob = dist.get_log_likelihood(x0, xp, mask).reshape(bs, fut_len) + elif self.cfg.dist_obj == "input": + if self.cfg.dist_type == "gaussian": + mu_u = seq_out_diag[..., :udim] + logvar_u = seq_out_diag[..., udim : 2 * udim] + scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( + *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim + ) + K = scene_noise_M.reshape( + bs * fut_len, agent_num, -1, self.cfg.scene_var_dim + ) + dist = MAGaussian(mu_u, torch.exp(logvar_u), K) + elif self.cfg.dist_type == "GMM": + seq_out_diag = seq_out_diag.reshape( + bs * fut_len, agent_num, self.GMM_M, -1 + ).transpose(1, 2) + mu_u = seq_out_diag[..., :udim] + logvar_u = seq_out_diag[..., udim : 2 * udim] + K = seq_out_diag[..., 2 * udim :].reshape( + *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim + ) + tf_feature = tf_out_tiled.reshape( + [fut_len + 1, agent_num, bs, fut_len, -1] + ) + tf_feature_diag = ( + torch.diagonal(tf_feature, dim1=0, dim2=3) + .permute(1, 3, 0, 2) + .reshape(bs * fut_len, agent_num, -1) + ) + tf_feature_pooled = tf_feature_diag.max(1)[0] + logpi = self.GMM_pi_net(tf_feature_pooled) + pi = torch.softmax(logpi, -1) + dist = MAGMM(mu_u, torch.exp(logvar_u), K, pi) + + xp = state_seq[1:].permute(2, 0, 1, 3).reshape(bs * fut_len, agent_num, -1) + x0 = ( + state_seq[:fut_len] + .permute(2, 0, 1, 3) + .reshape(bs * fut_len, agent_num, -1) + ) + up = self.dyn.inverse_dyn(x0, xp) + log_prob = dist.get_log_likelihood(up, mask) + + data["log_prob"] = log_prob + data["dist"] = dist + + def forward( + self, + data, + mode, + sample_num=1, + autoregress=True, + z=None, + need_weights=False, + cond_idx=None, + temp=0.1, + predict=False, + gt_step=0, + ): + agent_num, bs = ( + data["fut_motion"].size(1), + data["fut_motion"].size(2), + ) + if mode == "train": + assert sample_num == 1 + # conditional input to the decoding process + context = data["context_enc"].repeat_interleave( + sample_num, dim=1 + ) # PN x BS x feat + + pre_motion = data["pre_motion"].repeat_interleave( + sample_num, dim=2 + ) # P x N X BS x 2 + fut_motion = data["fut_motion"].repeat_interleave( + sample_num, dim=2 + ) # F x N X BS x 2 + pre_motion_scene_norm = data["pre_motion_scene_norm"].repeat_interleave( + sample_num, dim=2 + ) # P x N x BS x 2 + fut_motion_scene_norm = data["fut_motion_scene_norm"].repeat_interleave( + sample_num, dim=2 + ) # F x N x BS x 2 + input_dict = dict( + pre_motion=pre_motion, + fut_motion=fut_motion, + pre_motion_scene_norm=pre_motion_scene_norm, + fut_motion_scene_norm=fut_motion_scene_norm, + ) + if self.pred_type == "vel": + input_dict["pre_vel"] = data["pre_vel"].repeat_interleave( + sample_num, dim=2 + ) # P x N x BS x 2 + input_dict["fut_vel"] = data["fut_vel"].repeat_interleave( + sample_num, dim=2 + ) # F x N x BS x 2 + elif self.pred_type in ["dynamic", "dynamic_var", "dynamic_AR"]: + traj_scale = data["traj_scale"] + pre_state = torch.cat( + ( + data["pre_motion"] * traj_scale, + torch.norm(data["pre_vel"], dim=-1, keepdim=True) * traj_scale, + data["pre_heading_raw"].transpose(0, 2).unsqueeze(-1), + ), + -1, + ) # P x N x B x 4 (unscaled) + fut_state = torch.cat( + ( + data["fut_motion"] * traj_scale, + torch.norm(data["fut_vel"], dim=-1, keepdim=True) * traj_scale, + data["fut_heading_raw"].transpose(0, 2).unsqueeze(-1), + ), + -1, + ) + + pre_state_vec = torch.cat( + (data["pre_motion"], data["pre_vel"], data["pre_heading_vec"]), -1 + ) # P x N x B x 6 (scaled) + fut_state_vec = torch.cat( + (data["fut_motion"], data["fut_vel"], data["fut_heading_vec"]), -1 + ) # F x N x B x 6 (scaled) + input_dict["curr_state"] = pre_state[-1].repeat_interleave( + sample_num, dim=1 + ) + input_dict["pre_state"] = pre_state.repeat_interleave(sample_num, dim=2) + input_dict["fut_state"] = fut_state.repeat_interleave(sample_num, dim=2) + input_dict["pre_state_vec"] = pre_state_vec.repeat_interleave( + sample_num, dim=2 + ) + input_dict["fut_state_vec"] = fut_state_vec.repeat_interleave( + sample_num, dim=2 + ) + + # p(z), compute prior distribution + if self.z_type != "None": + if mode == "infer": + prior_key = "p_z_dist_infer" + else: + prior_key = "q_z_dist" if "q_z_dist" in data else "p_z_dist" + + if self.learn_prior: + p_z_params0 = self.p_z_net(data["agent_context"]) + + h = data["agent_context"].repeat_interleave( + sample_num, dim=0 + ) # NBS x feat + p_z_params = self.p_z_net(h) # NBS x 64 + if self.z_type == "gaussian": + data["p_z_dist_infer"] = Normal(params=p_z_params) + data["p_z_dist"] = Normal(params=p_z_params0) + else: + data["p_z_dist_infer"] = Categorical(logits=p_z_params, temp=temp) + data["p_z_dist"] = Categorical(logits=p_z_params0, temp=temp) + else: + if self.z_type == "gaussian": + data[prior_key] = Normal( + mu=torch.zeros(pre_motion.shape[1], self.nz).to( + pre_motion.device + ), + logvar=torch.zeros(pre_motion.shape[1], self.nz).to( + pre_motion.device + ), + ) + else: + data[prior_key] = Categorical( + logits=torch.zeros(pre_motion.shape[1], self.nz).to( + pre_motion.device + ) + ) + + # sample latent code from the distribution + if z is None: + # use latent code z from posterior for training + + # use latent code z from posterior for evaluating the reconstruction loss + if mode == "recon": + z = data["q_z_dist"].mode() # NB x 32 + z = z.view(agent_num, bs, z.size(-1)) # N x B x 32 + + # use latent code z from the prior for inference + elif mode in ["infer", "train"]: + # dist = data["p_z_dist_infer"] if "p_z_dist_infer" in data else data["q_z_dist_infer"] + # z = dist.sample() # NBS x 32 + # import pdb + # pdb.set_trace() + + dist = ( + data["q_z_dist"] + if data["q_z_dist"] is not None + else data["p_z_dist"] + ) + if self.z_type == "gaussian": + if predict: + z = dist.pseudo_sample(sample_num) + else: + z = data["p_z_dist_infer"].sample() + D = z.shape[-1] + samples = z.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3) + mu = dist.mu.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3)[ + :, :, 0 + ] + sigma = dist.sigma.reshape(agent_num, bs, -1, D).permute( + 1, 0, 2, 3 + )[:, :, 0] + data["prob"] = self.pseudo_sample_prob( + samples, mu, sigma, data["agent_avail"] + ) + elif self.z_type == "discrete": + if predict: + z = dist.pseudo_sample(sample_num).contiguous() + else: + z = dist.rsample(sample_num).contiguous() + D = z.shape[-1] + idx = z.argmax(dim=-1) + prob_sample = torch.gather(dist.probs, -1, idx) + prob_sample = prob_sample.reshape(agent_num, bs, -1).mean(0) + prob_sample = prob_sample / prob_sample.sum(-1, keepdim=True) + data["prob"] = prob_sample + + z = z.view(agent_num, bs * sample_num, z.size(-1)) # N x BS x 32 + else: + raise ValueError("Unknown Mode!") + else: + z = None + + # trajectory decoding + # if mode=="train": + # self.calc_traj_likelihood(data, context, input_dict, z) + # elif mode=="infer": + # self.calc_traj_likelihood(data, context, input_dict, z) + # self.decode_traj_ar( + # data, + # mode, + # context, + # input_dict, + # z, + # sample_num, + # gt_step = gt_step, + # need_weights=need_weights, + # ) + + # else: + # raise NotImplementedError + self.calc_traj_likelihood(data, context, input_dict, z) + self.decode_traj_ar( + data, + mode, + context, + input_dict, + z, + sample_num, + gt_step=gt_step, + need_weights=need_weights, + ) + + def pseudo_sample_prob(self, sample, mu, sigma, mask): + """ + A simple K-means estimation to estimate the probability of samples + """ + bs, Na, Ns, D = sample.shape + device = sample.device + Np = Ns * 50 + particle = torch.randn([bs, Na, Np, D]).to(device) * sigma.unsqueeze( + -2 + ) + mu.unsqueeze(-2) + dis = torch.linalg.norm(sample.unsqueeze(-2) - particle.unsqueeze(-3), dim=-1) + dis = (dis * mask[..., None, None]).sum(1) + idx = torch.argmin(dis, -2) + flag = idx.unsqueeze(1) == torch.arange(Ns).view(1, Ns, 1).repeat_interleave( + bs, 0 + ).to(device) + prob = flag.sum(-1) / Np + return prob + + +""" AgentFormer """ + + +class AgentFormer(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + input_type = cfg.input_type + pred_type = cfg.pred_type + if type(input_type) == str: + input_type = [input_type] + fut_input_type = cfg.fut_input_type + dec_input_type = cfg.dec_input_type + + self.use_map = cfg.use_map + self.rand_rot_scene = cfg.rand_rot_scene + self.discrete_rot = cfg.discrete_rot + self.map_global_rot = cfg.map_global_rot + self.ar_train = cfg.ar_train + self.max_train_agent = cfg.max_train_agent + self.loss_cfg = cfg.loss_cfg + self.param_annealers = nn.ModuleList() + self.z_type = cfg.z_type + if self.z_type == "discrete": + z_tau_annealer = ExpParamAnnealer( + cfg.z_tau.start, cfg.z_tau.finish, cfg.z_tau.decay + ) + self.param_annealers.append(z_tau_annealer) + self.z_tau_annealer = z_tau_annealer + + self.ego_conditioning = cfg.ego_conditioning + self.step_time = cfg.step_time + self.dyn = dynamics.Unicycle(cfg.step_time) + self.DoubleIntegrator = dynamics.DoubleIntegrator(cfg.step_time) + if "perturb" in cfg and cfg.perturb.enabled: + self.N_pert = cfg.perturb.N_pert + theta = cfg.perturb.OU.theta + sigma = cfg.perturb.OU.sigma + scale = torch.tensor(cfg.perturb.OU.scale) + self.pert = DynOrnsteinUhlenbeckPerturbation( + theta * torch.ones(self.dyn.udim), sigma * scale, self.dyn + ) + else: + self.N_pert = 0 + self.pert = None + if "stage" in cfg: + assert cfg.stage * cfg.num_frames_per_stage <= cfg.future_num_frames + self.stage = cfg.stage + self.num_frames_per_stage = cfg.num_frames_per_stage + else: + self.stage = 1 + self.num_frames_per_stage = cfg.future_num_frames + + # save all computed variables + self.data = dict() + + # map encoder + if self.use_map: + self.map_encoder = base_models.RasterizedMapEncoder( + model_arch=cfg.map_encoder.model_architecture, + input_image_shape=cfg.map_encoder.image_shape, + feature_dim=cfg.map_encoder.feature_dim, + use_spatial_softmax=cfg.map_encoder.spatial_softmax.enabled, + spatial_softmax_kwargs=cfg.map_encoder.spatial_softmax.kwargs, + ) + + # models + self.context_encoder = ContextEncoder(cfg) + self.future_encoder = FutureEncoder(cfg) + self.future_decoder = FutureDecoder(cfg) + + def set_data(self, batch, stage=0): + device = batch["pre_motion_raw"].device + self.data[stage] = batch + self.data[stage]["step_time"] = self.step_time + bs, Na = batch["pre_motion_raw"].shape[:2] + self.data[stage]["pre_motion"] = ( + batch["pre_motion_raw"].to(device).transpose(0, 2).contiguous() + ) # P x N x B x 2 + self.data[stage]["fut_motion"] = ( + batch["fut_motion_raw"].to(device).transpose(0, 2).contiguous() + ) # F x N x B x 2 + + # compute the origin of the current scene, i.e., the center + # of the agents' location in the current frame + self.data[stage]["scene_orig"] = torch.nanmean( + self.data[stage]["pre_motion"][-1], dim=0 + ) # B x 2 + + # normalize the scene with respect to the center location + # optionally, also rotate the scene for augmentation + if self.rand_rot_scene and self.training: + # below cannot be fixed in seed, causing reproducibility issue + if self.discrete_rot: + theta = torch.randint(high=24, size=(1,)).to(device) * (np.pi / 12) + else: + theta = torch.rand(1).to(device) * np.pi * 2 # [0, 2*pi], full circle + + for key in ["pre_motion", "fut_motion"]: + ( + self.data[stage][f"{key}"], + self.data[stage][f"{key}_scene_norm"], + ) = rotation_2d_torch( + self.data[stage][key], theta, self.data[stage]["scene_orig"] + ) + if self.data[stage]["heading"] is not None: + self.data[stage]["heading"] += theta # B x N + else: + theta = torch.zeros(1).to(device) + + # normalize per scene + for key in ["pre_motion", "fut_motion"]: # (F or P) x N x B x 2 + self.data[stage][f"{key}_scene_norm"] = ( + self.data[stage][key] - self.data[stage]["scene_orig"] + ) + + # normalize pos per agent + self.data[stage]["cur_motion"] = self.data[stage]["pre_motion"][ + [-1] + ] # 1 x N x B x 2 + self.data[stage]["pre_motion_norm"] = ( + self.data[stage]["pre_motion"][:-1] + - self.data[stage]["cur_motion"] # P x N x B x 2 + ) + self.data[stage]["fut_motion_norm"] = ( + self.data[stage]["fut_motion"] - self.data[stage]["cur_motion"] + ) # F x N x B x 2 + + # vectorize heading + if self.data[stage]["heading"] is not None: + self.data[stage]["heading_vec"] = torch.stack( + [ + torch.cos(self.data[stage]["heading"]), + torch.sin(self.data[stage]["heading"]), + ], + dim=-1, + ).transpose(0, 1) + # N x B x 2 + self.data[stage]["pre_heading_vec"] = torch.stack( + [ + torch.cos(self.data[stage]["pre_heading_raw"]), + torch.sin(self.data[stage]["pre_heading_raw"]), + ], + dim=-1, + ).transpose(0, 2) + # P x N x B x 2 + + self.data[stage]["fut_heading_vec"] = torch.stack( + [ + torch.cos(self.data[stage]["fut_heading_raw"]), + torch.sin(self.data[stage]["fut_heading_raw"]), + ], + dim=-1, + ).transpose(0, 2) + # F x N x B x 2 + + # agent shuffling, default not shuffling + if self.training and self.cfg["agent_enc_shuffle"]: + self.data[stage]["agent_enc_shuffle"] = torch.randperm( + self.cfg["max_agent_len"] + )[: self.data[stage]["agent_num"]].to(device) + else: + self.data[stage]["agent_enc_shuffle"] = None + + # mask between pairwse agents, such as diable connection for a pair of agents + # that are far away from each other, currently not used, i.e., assuming all connections + conn_dist = self.cfg.conn_dist + cur_motion = self.data[stage]["cur_motion"][0] + if conn_dist < 1000.0: + threshold = conn_dist / self.cfg.traj_scale + pdist = F.pdist(cur_motion) + D = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) + D[np.triu_indices(cur_motion.shape[0], 1)] = pdist + D += D.T + mask = torch.zeros_like(D) + mask[D > threshold] = float("-inf") + else: + mask = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) + self.data[stage][ + "agent_mask" + ] = mask # N x N, all zeros now, i.e., fully-connected + + def step_annealer(self): + for anl in self.param_annealers: + anl.step() + + def convert_data(self, batch, cond_traj=None, predict=False): + data = defaultdict(lambda: None) + ego_traj = torch.cat((batch["fut_pos"][:, 0], batch["fut_yaw"][:, 0]), -1) + external_cond = True if cond_traj is not None else False + if cond_traj is None and not predict: + if self.pert is not None: + # always perturb the ego trajectory + + ego_traj_tiled = ego_traj.repeat_interleave(self.N_pert, 0) + avail = batch["fut_mask"][:, 0].repeat_interleave(self.N_pert, 0) + pert_dict = self.pert.perturb( + dict( + fut_pos=ego_traj_tiled[..., :2], + fut_yaw=ego_traj_tiled[..., 2:], + fut_mask=avail, + step_time=self.step_time, + ) + ) + pert_ego_positions = pert_dict["fut_pos"] + pert_ego_trajectories = torch.cat( + (pert_dict["fut_pos"], pert_dict["fut_yaw"]), -1 + ) + + cond_traj = torch.cat( + ( + ego_traj.unsqueeze(1), + TensorUtils.reshape_dimensions_single( + pert_ego_trajectories, 0, 1, [-1, self.N_pert] + ), + ), + 1, + ) + else: + cond_traj = ego_traj.unsqueeze(1) + + device = batch["hist_pos"].device + bs = batch["hist_yaw"].shape[0] + data["heading"] = batch["hist_yaw"][:, :, -1, 0].to(device) # B x N + data["pre_heading_raw"] = batch["hist_yaw"][..., 0].to(device) # B x N x P + data["fut_heading_full"] = batch["fut_yaw"][..., 0].to(device) + data["fut_heading_raw"] = data["fut_heading_full"][ + ..., : self.num_frames_per_stage + ] # B x N x F + traj_scale = self.cfg.traj_scale + data["traj_scale"] = traj_scale + # AgentFormer uses the x/y inputs, i.e., the first two dimensions + data["pre_motion_raw"] = (batch["hist_pos"] / traj_scale).to( + device + ) # B x N x P x 2 + data["fut_motion_full"] = (batch["fut_pos"] / traj_scale).to(device) + data["fut_motion_raw"] = ( + batch["fut_pos"][:, :, : self.num_frames_per_stage] / traj_scale + ).to( + device + ) # B x N x F x 2 + + data["pre_mask"] = ( + batch["hist_mask"].float().to(device) + ) # B x N x P # B x N x F x 2 + data["fut_mask_full"] = batch["fut_mask"].float().to(device) # B x N x F + data["fut_mask"] = data["fut_mask_full"][..., : self.num_frames_per_stage] + data["agent_avail"] = data["pre_mask"].any(-1).float() + data["image"] = batch["image"] + if cond_traj is not None and self.ego_conditioning: + Ne = cond_traj.shape[1] + for k in [ + "heading", + "pre_motion_raw", + "fut_motion_full", + "pre_mask", + "fut_mask_full", + "fut_mask", + "agent_avail", + "image", + "pre_heading_raw", + "fut_heading_raw", + "fut_heading_full", + ]: + data[k] = ( + data[k].repeat_interleave(Ne, 0) if data[k] is not None else None + ) + + fut_motion_full = data["fut_motion_full"] * traj_scale + cond_traj_tiled = TensorUtils.join_dimensions(cond_traj, 0, 2) + fut_motion_full[:, 0] = cond_traj_tiled[..., :2] + data["fut_heading_full"][:, 0] = cond_traj_tiled[..., 2] + data["fut_heading_raw"] = data["fut_heading_full"][ + ..., : self.num_frames_per_stage + ] + data["fut_motion_full"] = fut_motion_full / traj_scale + data["fut_motion_raw"] = data["fut_motion_full"][ + :, :, : self.num_frames_per_stage + ] + + if self.ego_conditioning: + data["cond_traj"] = cond_traj + else: + data["cond_traj"] = None + data["pre_vel"] = self.DoubleIntegrator.calculate_vel( + data["pre_motion_raw"], None, data["pre_mask"].bool() + ) + data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() + data["fut_vel"] = self.DoubleIntegrator.calculate_vel( + data["fut_motion_raw"], None, data["fut_mask"].bool() + ) # F x N x B x 2 + data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() + + return data + + def gen_data_stage(self, batch, pred_traj, stage): + if stage == 0: + return batch + else: + data = defaultdict(lambda: None) + device = pred_traj.device + # fields that does not change + for k in [ + "traj_scale", + "agent_enc_shuffle", + "fut_motion_full", + "fut_mask_full", + ]: + data[k] = batch[k] + traj_scale = self.cfg.traj_scale + bs, M, Na = pred_traj.shape[:3] + data["heading"] = batch["heading"].repeat_interleave(M, 0) # (B*M) x N + + Ts = self.num_frames_per_stage + P = self.cfg.history_num_frames + F = self.cfg.future_num_frames + if Ts < P: + # left over from previous stage + + prev_stage_hist_pos = batch["pre_motion_raw"][ + :, :, Ts - P : + ].repeat_interleave( + M, 0 + ) # (B*M) x N x (P-Ts) x 2 + prev_stage_hist_yaw = batch["pre_heading_raw"][ + :, :, Ts - P : + ].repeat_interleave( + M, 0 + ) # (B*M) x N x (P-Ts) + new_hist_pos = TensorUtils.join_dimensions( + pred_traj[:, :, :, :Ts, :2], 0, 2 + ) # (B*M) x N x Ts x 2 + new_hist_yaw = TensorUtils.join_dimensions( + pred_traj[:, :, :, :Ts, 2], 0, 2 + ) # (B*M) x N x Ts + + data["pre_motion_raw"] = torch.cat( + (prev_stage_hist_pos, new_hist_pos), 2 + ) # (B*M) x N x P x 2 + data["pre_heading_raw"] = torch.cat( + (prev_stage_hist_yaw, new_hist_yaw), 2 + ) # (B*M) x N x P + + prev_stage_pre_mask = batch["pre_mask"][ + :, :, Ts - P : + ].repeat_interleave( + M, 0 + ) # (B*M) x N x (P-Ts) + # since this is associated with the predicted trajectory, all entries is True except for dummy agents + new_stage_pre_mask = ( + batch["agent_avail"] + .unsqueeze(-1) + .repeat_interleave(M, 0) + .repeat_interleave(Ts, -1) + ) # (B*M) x N x Ts + data["pre_mask"] = torch.cat( + (prev_stage_pre_mask, new_stage_pre_mask), -1 + ) # (B*M) x N x P + else: + data["pre_motion_raw"] = TensorUtils.join_dimensions( + pred_traj[:, :, :, -P:, :2], 0, 2 + ) # (B*M) x N x P x 2 + data["pre_heading_raw"] = TensorUtils.join_dimensions( + pred_traj[:, :, :, -P:], 0, 2 + ) # (B*M) x N x P + data["pre_mask"] = ( + batch["agent_avail"] + .unsqueeze(-1) + .repeat_interleave(M, 0) + .repeat_interleave(P, -1) + ) # (B*M) x N x P + # for future motion, pad the unknown future with 0 + + data["fut_motion_raw"] = batch["fut_motion_full"][ + ..., stage * Ts : (stage + 1) * Ts, : + ].repeat_interleave( + M**stage, 0 + ) # (B*M) x N x Ts x 2 + data["fut_heading_raw"] = batch["fut_heading_full"][ + ..., stage * Ts : (stage + 1) * Ts + ].repeat_interleave( + M**stage, 0 + ) # (B*M) x N x Ts + + data["fut_mask"] = batch["fut_mask_full"][ + ..., stage * Ts : (stage + 1) * Ts + ].repeat_interleave(M**stage, 0) + + data["agent_avail"] = batch["agent_avail"].repeat_interleave(M, 0) + + data["pre_vel"] = self.DoubleIntegrator.calculate_vel( + data["pre_motion_raw"], None, data["pre_mask"].bool() + ) + data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() + data["fut_vel"] = self.DoubleIntegrator.calculate_vel( + data["fut_motion_raw"], None, data["fut_mask"].bool() + ) # F x N x B x 2 + data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() + if data["map_enc"] is not None: + data["map_enc"] = batch["map_enc"].repeat_interleave( + M, 1 + ) # N x (B*M) x D + + return data + + def forward(self, batch, sample_k=None, predict=False, **kwargs): + cond_traj = kwargs["cond_traj"] if "cond_traj" in kwargs else None + data0 = self.convert_data(batch, cond_traj=cond_traj, predict=predict) + pred_traj = None + pred_batch = dict() + pred_batch["p_z_dist"] = dict() + pred_batch["q_z_dist"] = dict() + if self.ego_conditioning: + cond_idx = [0] + else: + cond_idx = None + data_stage = data0 + for stage in range(self.stage): + data_stage = self.gen_data_stage(data_stage, pred_traj, stage) + self.set_data(data_stage, stage) + pred_data = self.run_model( + stage, sample_k, predict=predict, cond_idx=cond_idx + ) + pred_traj = pred_data["infer_dec_motion"] + if "infer_dec_state" not in pred_data: + yaws = torch.zeros_like(pred_traj[..., 0:1]) + else: + yaws = pred_data["infer_dec_state"][..., 3:] + + pred_traj = torch.cat((pred_traj, yaws), -1) + pred_batch["p_z_dist"][stage] = pred_data["p_z_dist"] + pred_batch["q_z_dist"][stage] = pred_data["q_z_dist"] + + positions, state, var, controls = self.batching_multistage_traj() + positions = positions * self.cfg.traj_scale + NeB, numMode, Na, F = positions.shape[:4] + bs = batch["hist_pos"].shape[0] + Ne = int(NeB / bs) + if state is None: + yaws = batch["hist_yaw"][:, :, [-1]].repeat_interleave(F, 2) + + yaws = ( + yaws.unsqueeze(1).repeat_interleave(Ne, 0).repeat_interleave(numMode, 1) + ) + trajectories = torch.cat((positions, yaws), -1) + else: + trajectories = state[..., [0, 1, 3]] + if "prob" not in self.data[0]: + prob = ( + torch.ones(trajectories.shape[:2]).to(trajectories.device) + / trajectories.shape[1] + ) + prob = prob / prob.sum(-1, keepdim=True) + else: + M = int(numMode ** (1 / self.stage)) + prob = self.data[self.stage - 1]["prob"].reshape( + bs * Ne, *([M] * self.stage) + ) + + for stage in range(self.stage - 1): + desired_shape = ( + [bs * Ne] + [M] * (stage + 1) + [1] * (self.stage - stage - 1) + ) + prob = prob * TensorUtils.reshape_dimensions( + self.data[stage]["prob"], 0, 2, desired_shape + ) + prob = TensorUtils.join_dimensions(prob, 1, self.stage + 1) + + pred_except_dist = dict( + trajectories=trajectories, + state_trajectory=state, + p=prob, + fut_pos=data0["fut_motion_full"] * self.cfg.traj_scale, + ) + pred_except_dist = TensorUtils.reshape_dimensions( + pred_except_dist, 0, 1, [bs, Ne] + ) + pred_batch.update(pred_except_dist) + if controls is not None: + pred_batch["controls"] = controls + pred_batch["cond_traj"] = data0["cond_traj"] + agent_avail = self.data[0]["agent_avail"] + agent_avail = agent_avail.reshape([bs, Ne, -1])[:, 0] + pred_batch["agent_avail"] = agent_avail + pred_batch.update(self._traj_to_preds(pred_batch["trajectories"])) + if var is not None: + pred_batch["var"] = var + if not predict: + self.step_annealer() + else: + pred_batch = {k: v for k, v in pred_batch.items() if "dist" not in k} + pred_batch["data_batch"] = batch + del data0 + return pred_batch + + def batching_multistage_traj(self): + if "infer_dec_motion" in self.data[0]: + infer_traj = list() + bs, M = self.data[0]["infer_dec_motion"].shape[:2] + for stage in range(self.stage): + traj_i = self.data[stage]["infer_dec_motion"].repeat_interleave( + (M ** (self.stage - stage - 1)), 0 + ) + traj_i = traj_i.reshape(bs, M**self.stage, *traj_i.shape[2:]) + infer_traj.append(traj_i) + infer_traj = torch.cat(infer_traj, -2) + else: + infer_traj = None + if "infer_dec_state" in self.data[0]: + infer_state = list() + bs, M = self.data[0]["infer_dec_state"].shape[:2] + for stage in range(self.stage): + state_i = self.data[stage]["infer_dec_state"].repeat_interleave( + (M ** (self.stage - stage - 1)), 0 + ) + state_i = state_i.reshape(bs, M**self.stage, *state_i.shape[2:]) + infer_state.append(state_i) + infer_state = torch.cat(infer_state, -2) + else: + infer_state = None + if "infer_var" in self.data[0]: + infer_var = list() + bs, M = self.data[0]["infer_var"].shape[:2] + for stage in range(self.stage): + var_i = self.data[stage]["infer_var"].repeat_interleave( + (M ** (self.stage - stage - 1)), 0 + ) + var_i = var_i.reshape(bs, M**self.stage, *var_i.shape[2:]) + infer_var.append(var_i) + infer_var = torch.cat(infer_var, -2) + else: + infer_var = None + + if "controls" in self.data[0]: + controls = list() + bs, M = self.data[0]["controls"].shape[:2] + for stage in range(self.stage): + controls_i = self.data[stage]["controls"].repeat_interleave( + (M ** (self.stage - stage - 1)), 0 + ) + controls_i = controls_i.reshape( + bs, M**self.stage, *controls_i.shape[2:] + ) + controls.append(controls_i) + controls = torch.cat(controls, -2) + else: + controls = None + + return infer_traj, infer_state, infer_var, controls + + def sample(self, batch, sample_k): + return self.forward(batch, sample_k) + + def run_model(self, stage, sample_k=None, predict=False, cond_idx=None): + if self.use_map and self.data[stage]["map_enc"] is None: + image = self.data[0]["image"] + bs, Na = image.shape[:2] + map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) + map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) + self.data[stage]["map_enc"] = map_enc.transpose(0, 1) + self.context_encoder(self.data[stage]) + if not predict: + self.future_encoder(self.data[stage]) + # self.future_decoder(self.data[stage], mode='train', autoregress=self.ar_train) + + if sample_k is None: + self.inference( + sample_num=self.cfg.sample_k, + stage=stage, + cond_idx=cond_idx, + predict=predict, + ) + else: + self.inference( + sample_num=sample_k, stage=stage, cond_idx=cond_idx, predict=predict + ) + + # self.data[stage]["cond_traj"] = None + return self.data[stage] + + def inference( + self, + mode="infer", + sample_num=20, + need_weights=False, + stage=0, + cond_idx=None, + predict=False, + ): + if self.use_map and self.data[stage]["map_enc"] is None: + image = self.data[0]["image"] + bs, Na = image.shape[:2] + map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) + map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) + self.data[stage]["map_enc"] = map_enc.transpose(0, 1) + if self.data[stage]["context_enc"] is None: + self.context_encoder(self.data[stage]) + if mode == "recon": + sample_num = 1 + self.future_encoder(self.data[stage], temp=self.z_tau_annealer.val()) + + if self.z_type == "gaussian": + temp = None + else: + temp = 0.0001 if predict else self.z_tau_annealer.val() + # raise Exception("one of p and q need to exist") + + self.future_decoder( + self.data[stage], + mode=mode, + sample_num=sample_num, + autoregress=True, + need_weights=need_weights, + cond_idx=cond_idx, + temp=temp, + predict=predict, + ) + return self.data[stage][f"{mode}_dec_motion"], self.data + + def _traj_to_preds(self, traj): + pred_positions = traj[..., :2] + pred_yaws = traj[..., 2:3] + return { + "trajectories": traj, + "predictions": {"positions": pred_positions, "yaws": pred_yaws}, + } + + def compute_losses(self, pred_batch, data_batch): + if "data_batch" in pred_batch: + data_batch = pred_batch["data_batch"] + device = pred_batch["trajectories"].device + bs, Ne, numMode, Na = pred_batch["trajectories"].shape[:4] + pred_batch["trajectories"] = pred_batch["trajectories"].nan_to_num(0) + M = int(numMode ** (1 / self.stage)) + kl_loss = torch.tensor(0.0, device=device) + if "q_z_dist" in pred_batch and "p_z_dist" in pred_batch: + for stage in range(self.stage): + kl_loss += ( + pred_batch["q_z_dist"][stage] + .kl(pred_batch["p_z_dist"][stage]) + .nan_to_num(0) + .sum(-1) + .mean() + ) + + kl_loss = kl_loss.clamp_min_(self.cfg.loss_cfg.kld.min_clip) + + traj_pred_tiled = TensorUtils.join_dimensions(pred_batch["trajectories"], 0, 2) + traj_pred_tiled2 = TensorUtils.join_dimensions(traj_pred_tiled, 0, 2) + cond_traj = pred_batch["cond_traj"] + if Ne > 1: + fut_mask = data_batch["fut_mask"].repeat_interleave(Ne, 0) + else: + fut_mask = data_batch["fut_mask"] + + pred_loss, goal_loss = MultiModal_trajectory_loss( + predictions=traj_pred_tiled[..., :2], + targets=TensorUtils.join_dimensions(pred_batch["fut_pos"], 0, 2), + availabilities=fut_mask, + prob=TensorUtils.join_dimensions(pred_batch["p"], 0, 2), + calc_goal_reach=False, + ) + extent = data_batch["extent"][..., :2] + div_score = diversity_score( + traj_pred_tiled[..., :2], + fut_mask.unsqueeze(1).repeat_interleave(numMode, 1).any(1), + ) + # cond_extent = extent[torch.arange(bs),pred_batch["cond_idx"]] + if pred_batch["cond_traj"] is not None: + if "EC_coll_loss" in pred_batch: + EC_coll_loss = pred_batch["EC_coll_loss"] + else: + EC_edges, type_mask = batch_utils().gen_EC_edges( + traj_pred_tiled2[:, 1:], + cond_traj.reshape(bs * Ne, 1, -1, 3) + .repeat_interleave(numMode, 0) + .repeat_interleave(Na - 1, 1), + extent[:, 0].repeat_interleave(Ne * numMode, 0), + extent[:, 1:].repeat_interleave(Ne * numMode, 0), + data_batch["type"][:, 1:].repeat_interleave(Ne * numMode, 0), + pred_batch["agent_avail"].repeat_interleave(Ne * numMode, 0)[:, 1:], + ) + + EC_edges = TensorUtils.reshape_dimensions( + EC_edges, 0, 1, (bs, Ne, numMode) + ) + type_mask = TensorUtils.reshape_dimensions( + type_mask, 0, 1, (bs, Ne, numMode) + ) + prob = pred_batch["p"] + EC_coll_loss = collision_loss_masked( + EC_edges, type_mask, weight=prob.reshape(bs, Ne, -1).unsqueeze(-1) + ) + if not isinstance(EC_coll_loss, torch.Tensor): + EC_coll_loss = torch.tensor(EC_coll_loss).to(device) + else: + EC_coll_loss = torch.tensor(0.0).to(device) + + # compute collision loss + + pred_edges = batch_utils().generate_edges( + pred_batch["agent_avail"].repeat_interleave(numMode * Ne, 0), + extent.repeat_interleave(Ne * numMode, 0), + traj_pred_tiled2[..., :2], + traj_pred_tiled2[..., 2:], + ) + + coll_loss = collision_loss(pred_edges=pred_edges) + if not isinstance(coll_loss, torch.Tensor): + coll_loss = torch.tensor(coll_loss).to(device) + + losses = OrderedDict( + prediction_loss=pred_loss, + kl_loss=kl_loss, + collision_loss=coll_loss, + EC_collision_loss=EC_coll_loss, + diversity_loss=-div_score, + ) + + if "controls" in pred_batch: + acce_reg_loss = (pred_batch["controls"][..., 0] ** 2).mean() + steering_reg_loss = (pred_batch["controls"][..., 1] ** 2).mean() + losses["acce_reg_loss"] = acce_reg_loss + losses["steering_reg_loss"] = steering_reg_loss + + # if self.cfg.input_weight_scaling is not None and "controls" in pred_batch: + # input_weight_scaling = torch.tensor(self.cfg.input_weight_scaling).to(pred_batch["controls"].device) + # losses["input_loss"] = torch.mean(pred_batch["controls"] ** 2 *pred_batch["mask"][...,None]*input_weight_scaling) + + return losses + + +class ARAgentFormer(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + input_type = cfg.input_type + pred_type = cfg.pred_type + if type(input_type) == str: + input_type = [input_type] + fut_input_type = cfg.fut_input_type + dec_input_type = cfg.dec_input_type + + self.use_map = cfg.use_map + self.rand_rot_scene = cfg.rand_rot_scene + self.discrete_rot = cfg.discrete_rot + self.map_global_rot = cfg.map_global_rot + self.ar_train = cfg.ar_train + self.max_train_agent = cfg.max_train_agent + self.loss_cfg = cfg.loss_cfg + self.param_annealers = nn.ModuleList() + self.z_type = cfg.z_type + if self.z_type == "discrete": + z_tau_annealer = ExpParamAnnealer( + cfg.z_tau.start, cfg.z_tau.finish, cfg.z_tau.decay + ) + self.param_annealers.append(z_tau_annealer) + self.z_tau_annealer = z_tau_annealer + # if "gt_step_anneal_length" in cfg and cfg.gt_step_anneal_length>0: + # self.gt_step_annealer = IntegerParamAnnealer(cfg.future_num_frames-1,0,cfg.gt_step_anneal_length) + # self.param_annealers.append(self.gt_step_annealer) + self.step_time = cfg.step_time + self.dyn = dynamics.Unicycle(cfg.step_time) + self.DoubleIntegrator = dynamics.DoubleIntegrator(cfg.step_time) + + # save all computed variables + self.data = dict() + + # map encoder + if self.use_map: + self.map_encoder = base_models.RasterizedMapEncoder( + model_arch=cfg.map_encoder.model_architecture, + input_image_shape=cfg.map_encoder.image_shape, + feature_dim=cfg.map_encoder.feature_dim, + use_spatial_softmax=cfg.map_encoder.spatial_softmax.enabled, + spatial_softmax_kwargs=cfg.map_encoder.spatial_softmax.kwargs, + ) + + # models + self.context_encoder = ContextEncoder(cfg) + self.future_encoder = FutureEncoder(cfg) + self.future_decoder = FutureARDecoder(cfg) + + self.future_num_frames = cfg.future_num_frames + self.history_num_frames = cfg.history_num_frames + + def set_data(self, batch): + device = batch["pre_motion_raw"].device + self.data = batch + self.data["step_time"] = self.step_time + bs, Na = batch["pre_motion_raw"].shape[:2] + self.data["pre_motion"] = ( + batch["pre_motion_raw"].to(device).transpose(0, 2).contiguous() + ) # P x N x B x 2 + self.data["fut_motion"] = ( + batch["fut_motion_raw"].to(device).transpose(0, 2).contiguous() + ) # F x N x B x 2 + + # compute the origin of the current scene, i.e., the center + # of the agents' location in the current frame + self.data["scene_orig"] = torch.nanmean( + self.data["pre_motion"][-1], dim=0 + ) # B x 2 + + # normalize the scene with respect to the center location + # optionally, also rotate the scene for augmentation + if self.rand_rot_scene and self.training: + # below cannot be fixed in seed, causing reproducibility issue + if self.discrete_rot: + theta = torch.randint(high=24, size=(1,)).to(device) * (np.pi / 12) + else: + theta = torch.rand(1).to(device) * np.pi * 2 # [0, 2*pi], full circle + + for key in ["pre_motion", "fut_motion"]: + ( + self.data[f"{key}"], + self.data[f"{key}_scene_norm"], + ) = rotation_2d_torch(self.data[key], theta, self.data["scene_orig"]) + if self.data["heading"] is not None: + self.data["heading"] += theta # B x N + else: + theta = torch.zeros(1).to(device) + + # normalize per scene + for key in ["pre_motion", "fut_motion"]: # (F or P) x N x B x 2 + self.data[f"{key}_scene_norm"] = ( + self.data[key] - self.data["scene_orig"] + ) + + # normalize pos per agent + self.data["cur_motion"] = self.data["pre_motion"][[-1]] # 1 x N x B x 2 + self.data["pre_motion_norm"] = ( + self.data["pre_motion"][:-1] - self.data["cur_motion"] # P x N x B x 2 + ) + self.data["fut_motion_norm"] = ( + self.data["fut_motion"] - self.data["cur_motion"] + ) # F x N x B x 2 + + # vectorize heading + if self.data["heading"] is not None: + self.data["heading_vec"] = torch.stack( + [torch.cos(self.data["heading"]), torch.sin(self.data["heading"])], + dim=-1, + ).transpose(0, 1) + # N x B x 2 + self.data["pre_heading_vec"] = torch.stack( + [ + torch.cos(self.data["pre_heading_raw"]), + torch.sin(self.data["pre_heading_raw"]), + ], + dim=-1, + ).transpose(0, 2) + # P x N x B x 2 + + self.data["fut_heading_vec"] = torch.stack( + [ + torch.cos(self.data["fut_heading_raw"]), + torch.sin(self.data["fut_heading_raw"]), + ], + dim=-1, + ).transpose(0, 2) + # F x N x B x 2 + + # agent shuffling, default not shuffling + if self.training and self.cfg["agent_enc_shuffle"]: + self.data["agent_enc_shuffle"] = torch.randperm(self.cfg["max_agent_len"])[ + : self.data["agent_num"] + ].to(device) + else: + self.data["agent_enc_shuffle"] = None + + # mask between pairwse agents, such as diable connection for a pair of agents + # that are far away from each other, currently not used, i.e., assuming all connections + conn_dist = self.cfg.conn_dist + cur_motion = self.data["cur_motion"][0] + if conn_dist < 1000.0: + threshold = conn_dist / self.cfg.traj_scale + pdist = F.pdist(cur_motion) + D = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) + D[np.triu_indices(cur_motion.shape[0], 1)] = pdist + D += D.T + mask = torch.zeros_like(D) + mask[D > threshold] = float("-inf") + else: + mask = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) + self.data["agent_mask"] = mask # N x N, all zeros now, i.e., fully-connected + + def step_annealer(self): + for anl in self.param_annealers: + anl.step() + + def convert_data(self, batch): + data = defaultdict(lambda: None) + + device = batch["hist_pos"].device + bs = batch["hist_yaw"].shape[0] + data["heading"] = batch["hist_yaw"][:, :, -1, 0].to(device) # B x N + data["pre_heading_raw"] = batch["hist_yaw"][..., 0].to(device) # B x N x P + data["fut_heading_full"] = batch["fut_yaw"][..., 0].to(device) + data["fut_heading_raw"] = data["fut_heading_full"][ + ..., : self.future_num_frames + ] # B x N x F + traj_scale = self.cfg.traj_scale + data["traj_scale"] = traj_scale + # AgentFormer uses the x/y inputs, i.e., the first two dimensions + data["pre_motion_raw"] = (batch["hist_pos"] / traj_scale).to( + device + ) # B x N x P x 2 + data["fut_motion_full"] = (batch["fut_pos"] / traj_scale).to(device) + data["fut_motion_raw"] = ( + batch["fut_pos"][:, :, : self.future_num_frames] / traj_scale + ).to( + device + ) # B x N x F x 2 + + data["pre_mask"] = ( + batch["hist_mask"].float().to(device) + ) # B x N x P # B x N x F x 2 + data["fut_mask_full"] = batch["fut_mask"].float().to(device) # B x N x F + data["fut_mask"] = data["fut_mask_full"][..., : self.future_num_frames] + data["agent_avail"] = data["pre_mask"].any(-1).float() + data["image"] = batch["image"] + + data["pre_vel"] = self.DoubleIntegrator.calculate_vel( + data["pre_motion_raw"], None, data["pre_mask"].bool() + ) + data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() + data["fut_vel"] = self.DoubleIntegrator.calculate_vel( + data["fut_motion_raw"], None, data["fut_mask"].bool() + ) # F x N x B x 2 + data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() + + return data + + def forward(self, batch, sample_k=None, predict=False, **kwargs): + data = self.convert_data(batch) + pred_batch = dict() + pred_batch["p_z_dist"] = dict() + pred_batch["q_z_dist"] = dict() + + self.set_data(data) + pred_data = self.run_model(sample_k, predict=predict) + + mode = "infer" if predict else "train" + if mode == "infer": + yaws = pred_data[f"{mode}_dec_state"][..., 3:] + + pred_batch["p_z_dist"] = pred_data["p_z_dist"] + pred_batch["q_z_dist"] = pred_data["q_z_dist"] + + positions = self.data[f"{mode}_dec_motion"] + state = self.data[f"{mode}_dec_state"] + controls = self.data["controls"] + positions = positions * self.cfg.traj_scale + bs, numMode, Na, F = positions.shape[:4] + if state is None: + yaws = batch["hist_yaw"][:, :, [-1]].repeat_interleave(F, 2) + + yaws = yaws.unsqueeze(1).repeat_interleave(numMode, 1) + trajectories = torch.cat((positions, yaws), -1) + else: + trajectories = state[..., [0, 1, 3]] + if "prob" not in self.data: + prob = ( + torch.ones(trajectories.shape[:2]).to(trajectories.device) + / trajectories.shape[1] + ) + prob = prob / prob.sum(-1, keepdim=True) + else: + prob = self.data["prob"].reshape(bs, -1) + + pred_except_dist = dict( + trajectories=trajectories, + state_trajectory=state, + p=prob, + fut_pos=self.data["fut_motion_full"] * self.cfg.traj_scale, + ) + pred_batch.update(pred_except_dist) + if controls is not None: + pred_batch["controls"] = controls + + pred_batch["agent_avail"] = self.data["agent_avail"] + pred_batch.update(self._traj_to_preds(pred_batch["trajectories"])) + pred_batch = {k: v for k, v in pred_batch.items() if "dist" not in k} + pred_batch["data_batch"] = batch + return pred_batch + else: + self.step_annealer() + return self.data + + def sample(self, batch, sample_k): + return self.forward(batch, sample_k) + + def run_model(self, sample_k=None, predict=False): + if self.use_map and self.data["map_enc"] is None: + image = self.data["image"] + bs, Na = image.shape[:2] + map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) + map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) + self.data["map_enc"] = map_enc.transpose(0, 1) + self.context_encoder(self.data) + mode = "infer" if predict else "train" + if mode == "infer": + if sample_k is None: + self.inference(sample_num=self.cfg.sample_k, predict=predict, mode=mode) + else: + self.inference(sample_num=sample_k, predict=predict, mode=mode) + else: + self.train_model(mode=mode) + return self.data + + def inference(self, mode="infer", sample_num=20, need_weights=False, predict=False): + if self.use_map and self.data["map_enc"] is None: + image = self.data["image"] + bs, Na = image.shape[:2] + map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) + map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) + self.data["map_enc"] = map_enc.transpose(0, 1) + if self.data["context_enc"] is None: + self.context_encoder(self.data) + + if self.z_type == "gaussian": + temp = None + else: + temp = 0.0001 if predict else self.z_tau_annealer.val() + # raise Exception("one of p and q need to exist") + + self.future_decoder( + self.data, + mode=mode, + sample_num=sample_num, + autoregress=True, + need_weights=need_weights, + temp=temp, + predict=predict, + gt_step=0, + ) + return self.data[f"{mode}_dec_motion"], self.data + + def train_model(self, mode="train", need_weights=False): + if self.use_map and self.data["map_enc"] is None: + image = self.data["image"] + bs, Na = image.shape[:2] + map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) + map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) + self.data["map_enc"] = map_enc.transpose(0, 1) + if self.data["context_enc"] is None: + self.context_encoder(self.data) + + if self.z_type == "discrete": + temp = self.z_tau_annealer.val() + else: + temp = None + # raise Exception("one of p and q need to exist") + + self.future_decoder( + self.data, + mode=mode, + sample_num=1, + autoregress=True, + need_weights=need_weights, + temp=temp, + ) + + return self.data + + def _traj_to_preds(self, traj): + pred_positions = traj[..., :2] + pred_yaws = traj[..., 2:3] + return { + "trajectories": traj, + "predictions": {"positions": pred_positions, "yaws": pred_yaws}, + } + + def compute_losses(self, pred_batch, data_batch): + losses = dict() + if "log_prob" in pred_batch: + log_prob_loss = -pred_batch["log_prob"].mean() + losses["log_prob"] = log_prob_loss + if "q_z_dist" in pred_batch and pred_batch["q_z_dist"] is not None: + kl_loss = ( + pred_batch["q_z_dist"] + .kl(pred_batch["p_z_dist"]) + .nan_to_num(0) + .sum(-1) + .mean() + ) + losses["kl_loss"] = kl_loss + + if "trajectories" in pred_batch: + # inference mode + if "data_batch" in pred_batch: + data_batch = pred_batch["data_batch"] + device = pred_batch["trajectories"].device + bs, numMode, Na = pred_batch["trajectories"].shape[:3] + pred_batch["trajectories"] = pred_batch["trajectories"].nan_to_num(0) + + traj_pred = pred_batch["trajectories"] + traj_pred_tiled2 = TensorUtils.join_dimensions(traj_pred, 0, 2) + + fut_mask = data_batch["fut_mask"] + + pred_loss, goal_loss = MultiModal_trajectory_loss( + predictions=traj_pred[..., :2], + targets=pred_batch["fut_pos"], + availabilities=fut_mask, + prob=pred_batch["p"], + calc_goal_reach=False, + ) + extent = data_batch["extent"][..., :2] + div_score = diversity_score( + traj_pred[..., :2], + fut_mask.unsqueeze(1).repeat_interleave(numMode, 1).any(1), + ) + + # compute collision loss + + pred_edges = batch_utils().generate_edges( + pred_batch["agent_avail"].repeat_interleave(numMode, 0), + extent.repeat_interleave(numMode, 0), + traj_pred_tiled2[..., :2], + traj_pred_tiled2[..., 2:], + ) + + coll_loss = collision_loss(pred_edges=pred_edges) + if not isinstance(coll_loss, torch.Tensor): + coll_loss = torch.tensor(coll_loss).to(device) + + pred_losses = OrderedDict( + prediction_loss=pred_loss, + collision_loss=coll_loss, + diversity_loss=-div_score, + ) + # if "controls" in pred_batch: + # scale = self.cfg.loss_weights.input_loss_scale if "input_loss_scale" in self.cfg.loss_weights else 1.0 + # acce_reg_loss = (pred_batch["controls"][...,0]**2).mean() + # steering_reg_loss = (pred_batch["controls"][...,1]**2).mean() + # pred_losses["acce_reg_loss"] = acce_reg_loss*scale + # pred_losses["steering_reg_loss"] = steering_reg_loss*scale + # pred_losses["acce_jerk_loss"] = torch.mean((pred_batch["controls"][...,1:,0]-pred_batch["controls"][...,:-1,0])**2)*scale + # pred_losses["steering_jerk_loss"] = torch.mean((pred_batch["controls"][...,1:,1]-pred_batch["controls"][...,:-1,1])**2)*scale + losses.update(pred_losses) + + # if self.cfg.input_weight_scaling is not None and "controls" in pred_batch: + # input_weight_scaling = torch.tensor(self.cfg.input_weight_scaling).to(pred_batch["controls"].device) + # losses["input_loss"] = torch.mean(pred_batch["controls"] ** 2 *pred_batch["mask"][...,None]*input_weight_scaling) + + return losses diff --git a/diffstack/models/agentformer_lib.py b/diffstack/models/agentformer_lib.py new file mode 100644 index 0000000..2e7eebb --- /dev/null +++ b/diffstack/models/agentformer_lib.py @@ -0,0 +1,1044 @@ +""" +Modified version of PyTorch Transformer module for the implementation of Agent-Aware Attention (L290-L308) +""" + +from typing import Callable, Dict, Final, List, Optional, Set, Tuple, Union +import warnings +import math +import copy + +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.functional import * +from torch.nn.modules.module import Module +from torch.nn.modules.activation import MultiheadAttention +from torch.nn.modules.container import ModuleList +from torch.nn.init import xavier_uniform_ +from torch.nn.modules.dropout import Dropout +from torch.nn.modules.linear import Linear +from torch.nn.modules.normalization import LayerNorm +from torch.nn.init import xavier_uniform_ +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.parameter import Parameter +from torch.overrides import has_torch_function, handle_torch_function +from torchvision import models + +def compute_z_kld(data, cfg): + loss_unweighted = data['q_z_dist_dlow'].kl(data['p_z_dist_infer']).sum() + if cfg.get('normalize', True): + loss_unweighted /= data['batch_size'] + loss_unweighted = loss_unweighted.clamp_min_(cfg.min_clip) + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + +def diversity_loss(data, cfg): + loss_unweighted = 0 + fut_motions = data['infer_dec_motion'].view(*data['infer_dec_motion'].shape[:2], -1) + for motion in fut_motions: + dist = F.pdist(motion, 2) ** 2 + loss_unweighted += (-dist / cfg['d_scale']).exp().mean() + if cfg.get('normalize', True): + loss_unweighted /= data['batch_size'] + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + +def recon_loss(data, cfg): + diff = data['infer_dec_motion'] - data['fut_motion_orig'].unsqueeze(1) + if cfg.get('mask', True): + mask = data['fut_mask'].unsqueeze(1).unsqueeze(-1) + diff *= mask + dist = diff.pow(2).sum(dim=-1).sum(dim=-1) + loss_unweighted = dist.min(dim=1)[0] + if cfg.get('normalize', True): + loss_unweighted = loss_unweighted.mean() + else: + loss_unweighted = loss_unweighted.sum() + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + + +# """ DLow (Diversifying Latent Flows)""" +# class DLow(nn.Module): +# def __init__(self, cfg): +# super().__init__() + +# self.device = torch.device('cpu') +# self.cfg = cfg +# self.nk = nk = cfg.sample_k +# self.nz = nz = cfg.nz +# self.share_eps = cfg.get('share_eps', True) +# self.train_w_mean = cfg.get('train_w_mean', False) +# self.loss_cfg = self.cfg.loss_cfg +# self.loss_names = list(self.loss_cfg.keys()) + +# pred_cfg = Config(cfg.pred_cfg, cfg.train_tag, tmp=False, create_dirs=False) +# pred_model = model_lib.model_dict[pred_cfg.model_id](pred_cfg) +# self.pred_model_dim = pred_cfg.tf_model_dim +# if cfg.pred_epoch > 0: +# cp_path = pred_cfg.model_path % cfg.pred_epoch +# print('loading model from checkpoint: %s' % cp_path) +# model_cp = torch.load(cp_path, map_location='cpu') +# pred_model.load_state_dict(model_cp['model_dict']) +# pred_model.eval() +# self.pred_model = [pred_model] + +# # Dlow's Q net +# self.qnet_mlp = cfg.get('qnet_mlp', [512, 256]) +# self.q_mlp = MLP(self.pred_model_dim, self.qnet_mlp) +# self.q_A = nn.Linear(self.q_mlp.out_dim, nk * nz) +# self.q_b = nn.Linear(self.q_mlp.out_dim, nk * nz) + +# def set_device(self, device): +# self.device = device +# self.to(device) +# self.pred_model[0].set_device(device) + +# def set_data(self, data): +# self.pred_model[0].set_data(data) +# self.data = self.pred_model[0].data + +# def main(self, mean=False, need_weights=False): +# pred_model = self.pred_model[0] +# if hasattr(pred_model, 'use_map') and pred_model.use_map: +# self.data['map_enc'] = pred_model.map_encoder(self.data['agent_maps']) +# pred_model.context_encoder(self.data) + +# if not mean: +# if self.share_eps: +# eps = torch.randn([1, self.nz]).to(self.device) +# eps = eps.repeat((self.data['agent_num'] * self.nk, 1)) +# else: +# eps = torch.randn([self.data['agent_num'], self.nz]).to(self.device) +# eps = eps.repeat_interleave(self.nk, dim=0) + +# qnet_h = self.q_mlp(self.data['agent_context']) +# A = self.q_A(qnet_h).view(-1, self.nz) +# b = self.q_b(qnet_h).view(-1, self.nz) + +# z = b if mean else A*eps + b +# logvar = (A ** 2 + 1e-8).log() +# self.data['q_z_dist_dlow'] = Normal(mu=b, logvar=logvar) + +# pred_model.future_decoder(self.data, mode='infer', sample_num=self.nk, autoregress=True, z=z, need_weights=need_weights) +# return self.data + +# def forward(self): +# return self.main(mean=self.train_w_mean) + +# def inference(self, mode, sample_num, need_weights=False): +# self.main(mean=True, need_weights=need_weights) +# res = self.data[f'infer_dec_motion'] +# if mode == 'recon': +# res = res[:, 0] +# return res, self.data + +# def compute_loss(self): +# total_loss = 0 +# loss_dict = {} +# loss_unweighted_dict = {} +# for loss_name in self.loss_names: +# loss, loss_unweighted = loss_func[loss_name](self.data, self.loss_cfg[loss_name]) +# total_loss += loss +# loss_dict[loss_name] = loss.item() +# loss_unweighted_dict[loss_name] = loss_unweighted.item() +# return total_loss, loss_dict, loss_unweighted_dict + +# def step_annealer(self): +# pass + +def agent_aware_attention(query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + gaussian_kernel = True, + num_agent = 1, + in_proj_weight_self = None, + in_proj_bias_self = None + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tgt_len, bs, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + # # replace nan with 0, so that we can use to check if it is a self-attention vs cross-attention + # query_no_nan = torch.nan_to_num(query.clone()) + # key_no_nan = torch.nan_to_num(key.clone()) + # value_no_nan = torch.nan_to_num(value.clone()) + + if not use_separate_proj_weight: + if torch.equal(query, key) and torch.equal(key, value): # PN x B X feat + # if torch.equal(query_no_nan, key_no_nan) and torch.equal(key_no_nan, value_no_nan): # PN x B X feat + + q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) # PN x B x feat + if in_proj_weight_self is not None: + q_self, k_self = linear(query, in_proj_weight_self, in_proj_bias_self).chunk(2, dim=-1) + + # elif torch.equal(key_no_nan, value_no_nan): + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = linear(key, _w, _b).chunk(2, dim=-1) + + if in_proj_weight_self is not None: + _w = in_proj_weight_self[:embed_dim, :] + _b = in_proj_bias_self[:embed_dim] + q_self = linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _w = in_proj_weight_self[embed_dim:, :] + _b = in_proj_bias_self[embed_dim:] + k_self = linear(key, _w, _b) + + else: + raise NotImplementedError + + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = linear(query, q_proj_weight_non_opt, in_proj_bias) + k = linear(key, k_proj_weight_non_opt, in_proj_bias) + v = linear(value, v_proj_weight_non_opt, in_proj_bias) + # k, q, v has PN x B X feat, q maybe FN x B x feat + + # default gaussian_kernel = False + if not gaussian_kernel: + q = q * scaling # remove scaling + if in_proj_weight_self is not None: + q_self = q_self * scaling # remove scaling + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bs * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bs, 1)]) + v = torch.cat([v, bias_v.repeat(1, bs, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bs * num_heads, head_dim).transpose(0, 1) # BH x PN x feat + if k is not None: + k = k.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) # BH x PN x feat + if v is not None: + v = v.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) + if in_proj_weight_self is not None: + q_self = q_self.contiguous().view(tgt_len, bs * num_heads, head_dim).transpose(0, 1) + k_self = k_self.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bs * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bs * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bs + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + if gaussian_kernel: + qk = torch.bmm(q, k.transpose(1, 2)) + q_n = q.pow(2).sum(dim=-1).unsqueeze(-1) + k_n = k.pow(2).sum(dim=-1).unsqueeze(1) + qk_dist = q_n + k_n - 2 * qk + attn_output_weights = qk_dist * scaling * 0.5 + else: + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # BH x PN x PN, or BH x FN x PN + # attention weights contain random numbers for the timestamps without data + + assert list(attn_output_weights.size()) == [bs * num_heads, tgt_len, src_len] + + if in_proj_weight_self is not None: + """ + ================================== + Agent-Aware Attention + ================================== + """ + attn_output_weights_inter = attn_output_weights # BH x PN x PN, or BH x FN x PN + attn_output_weights_self = torch.bmm(q_self, k_self.transpose(1, 2)) # BH x PN x PN, or BH x FN x PN + + # using identity matrix here since the agents are not shuffled + attn_weight_self_mask = torch.eye(num_agent).to(q.device) + attn_weight_self_mask = attn_weight_self_mask.repeat([attn_output_weights.shape[1] // num_agent, attn_output_weights.shape[2] // num_agent]).unsqueeze(0) + # 1 x PN x PN + + attn_output_weights = attn_output_weights_inter * (1 - attn_weight_self_mask) + attn_output_weights_self * attn_weight_self_mask # BH x PN x PN + + # masking the columns / rows + if attn_mask is not None: # BH x PN x PN or BH x FN x PN + if attn_mask.dtype == torch.bool: + + # assign -inf so that the columns of the invalid data will lead to 0 after softmax during attention + # this is to disable the interaction between valid agents in rows and invalid agents in some columns + attn_output_weights.masked_fill_(attn_mask, float('-inf')) # BH x PN x PN or BH x FN x PN + + # however, the rows with invalid data (with all -inf now) will lead to NaN after softmax + # because there is no single valid data for that row + # as a result, it will lead to NaN in backward during softmax + # we need to assign some dummy numbers to it, this process is needed for training + # but this does not affect the results in forward pass since these rows of features are not used + attn_output_weights.masked_fill_(attn_mask[:, :, [0]], 0.0) + else: + attn_output_weights += attn_mask # BH x PN x PN or BH x FN x PN + + attn_output_weights = softmax(attn_output_weights, dim=-1) # BH x PN x PN or BH x FN x PN + # the output attn_output_weights will have 0.17 for the rows with invalid data + # because the entire row prior to softmax is 0, so it is averaged over the row + + # to suppress the random number in the row with invalida data + # we mask again with 0s, however in-place operation not supported for backward + # attn_output_weights.masked_fill_(attn_mask[[0], :, [0]].unsqueeze(-1), 0.0) + else: + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bs, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bs * num_heads, tgt_len, src_len) + attn_output_weights = softmax( + attn_output_weights, dim=-1) + + # attn_output_weights is row-wise, i.e., the agent at a timestamp without valid data (NaN) has some random numbers + # for the columns, the agent at a timestamp without valid data is 0 + # add torch.nan_to_num to convert NaN to a large number for the agent value without valid data + # but when the 0 in attn_output_weights col * the large number in v, it will result in 0 + # in other words, we do not attend to the agent timestamp with NaN data + # the output might have some invalid rows (rows with random numbers but do not affect results) + attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) # BH x PN x PN, or BH x FN x PN + attn_output = torch.bmm(attn_output_weights, torch.nan_to_num(v, nan=1e+10)) # BH x PN x feat, or BH x FN x feat + + # to maintain elegancy, we mask those invalid rows with random numbers as 0s + # but not masking will not affect results + # final_mask = attn_mask[:, :, [0]] # 1 x PN x 1 + # attn_output = attn_output.masked_fill_(final_mask, 0.0) # BH x PN x feat + + # convert to output shape + assert list(attn_output.size()) == [bs * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bs, embed_dim) # PN x B x feat + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) # PN x B x feat + + # average attention weights over heads + if need_weights: + attn_output_weights = attn_output_weights.view(bs, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class AgentAwareAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, cfg, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super().__init__() + self.cfg = cfg + self.gaussian_kernel = self.cfg.get('gaussian_kernel', False) + self.sep_attn = self.cfg.get('sep_attn', True) + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + if self.sep_attn: + self.in_proj_weight_self = Parameter(torch.empty(2 * embed_dim, embed_dim)) + self.in_proj_bias_self = Parameter(torch.empty(2 * embed_dim)) + else: + self.in_proj_weight_self = self.in_proj_bias_self = None + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + if self.sep_attn: + xavier_uniform_(self.in_proj_weight_self) + constant_(self.in_proj_bias_self, 0.) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super().__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None, num_agent=1): + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return agent_aware_attention( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, gaussian_kernel=self.gaussian_kernel, + num_agent=num_agent, + in_proj_weight_self=self.in_proj_weight_self, + in_proj_bias_self=self.in_proj_bias_self + ) + else: + return agent_aware_attention( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, gaussian_kernel=self.gaussian_kernel, + num_agent=num_agent, + in_proj_weight_self=self.in_proj_weight_self, + in_proj_bias_self=self.in_proj_bias_self + ) + + +class AgentFormerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__(self, cfg, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): + super().__init__() + self.cfg = cfg + self.self_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1) -> torch.Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + src2 = self.self_attn(src, src, src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, num_agent=num_agent)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class AgentFormerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, cfg, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): + super().__init__() + self.cfg = cfg + self.self_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) + self.multihead_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, num_agent = 1, need_weights = False) -> torch.Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, num_agent=num_agent, need_weights=need_weights) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2, cross_attn_weights = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, num_agent=num_agent, need_weights=need_weights) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt, self_attn_weights, cross_attn_weights + + +class AgentFormerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1) -> torch.Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + + for mod in self.layers: + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, num_agent=num_agent) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class AgentFormerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1, need_weights = False) -> torch.Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = tgt + + self_attn_weights = [None] * len(self.layers) + cross_attn_weights = [None] * len(self.layers) + for i, mod in enumerate(self.layers): + output, self_attn_weights[i], cross_attn_weights[i] = mod(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + num_agent=num_agent, need_weights=need_weights) + + if self.norm is not None: + output = self.norm(output) + + if need_weights: + self_attn_weights = torch.stack(self_attn_weights).cpu().numpy() + cross_attn_weights = torch.stack(cross_attn_weights).cpu().numpy() + + return output, {'self_attn_weights': self_attn_weights, 'cross_attn_weights': cross_attn_weights} + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + +class MapCNN(nn.Module): + def __init__(self, cfg): + super().__init__() + self.convs = nn.ModuleList() + map_channels = cfg.get('map_channels', 3) + patch_size = cfg.get('patch_size', [100, 100]) + hdim = cfg.get('hdim', [32, 32]) + kernels = cfg.get('kernels', [3, 3]) + strides = cfg.get('strides', [3, 3]) + self.out_dim = out_dim = cfg.get('out_dim', 32) + self.input_size = input_size = (map_channels, patch_size[0], patch_size[1]) + x_dummy = torch.randn(input_size).unsqueeze(0) + + for i, _ in enumerate(hdim): + self.convs.append(nn.Conv2d(map_channels if i == 0 else hdim[i-1], + hdim[i], kernels[i], + stride=strides[i])) + x_dummy = self.convs[i](x_dummy) + + self.fc = nn.Linear(x_dummy.numel(), out_dim) + + def forward(self, x): + for conv in self.convs: + x = F.leaky_relu(conv(x), 0.2) + x = torch.flatten(x, start_dim=1) + x = self.fc(x) + return x + +class MapEncoder(nn.Module): + def __init__(self, cfg): + super().__init__() + model_id = cfg.get('model_id', 'map_cnn') + dropout = cfg.get('dropout', 0.0) + self.normalize = cfg.get('normalize', True) + self.dropout = nn.Dropout(dropout) + if model_id == 'map_cnn': + self.model = MapCNN(cfg) + self.out_dim = self.model.out_dim + elif 'resnet' in model_id: + model_dict = { + 'resnet18': models.resnet18, + 'resnet34': models.resnet34, + 'resnet50': models.resnet50 + } + self.out_dim = out_dim = cfg.get('out_dim', 32) + self.model = model_dict[model_id](pretrained=False, norm_layer=nn.InstanceNorm2d) + self.model.fc = nn.Linear(self.model.fc.in_features, out_dim) + else: + raise ValueError('unknown map encoder!') + + def forward(self, x): + if self.normalize: + x = x * 2. - 1. + x = self.model(x) + x = self.dropout(x) + return x + +def compute_motion_mse(data, cfg): + diff = data['fut_motion_orig'] - data['train_dec_motion'] + # print(data['fut_motion_orig']) + # print(data['train_dec_motion']) + # zxc + if cfg.get('mask', True): + mask = data['fut_mask'] + diff *= mask.unsqueeze(2) + loss_unweighted = diff.pow(2).sum() + if cfg.get('normalize', True): + loss_unweighted /= diff.shape[0] + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + +def compute_z_kld(data, cfg): + loss_unweighted = data['q_z_dist'].kl(data['p_z_dist']).sum() + if cfg.get('normalize', True): + loss_unweighted /= data['batch_size'] + loss_unweighted = loss_unweighted.clamp_min_(cfg.min_clip) + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + +def compute_sample_loss(data, cfg): + diff = data['infer_dec_motion'] - data['fut_motion_orig'].unsqueeze(1) + if cfg.get('mask', True): + mask = data['fut_mask'].unsqueeze(1).unsqueeze(-1) + diff *= mask + dist = diff.pow(2).sum(dim=-1).sum(dim=-1) + loss_unweighted = dist.min(dim=1)[0] + if cfg.get('normalize', True): + loss_unweighted = loss_unweighted.mean() + else: + loss_unweighted = loss_unweighted.sum() + loss = loss_unweighted * cfg['weight'] + return loss, loss_unweighted + + +loss_func = { + 'mse': compute_motion_mse, + 'kld': compute_z_kld, + 'sample': compute_sample_loss +} \ No newline at end of file diff --git a/diffstack/models/base_models.py b/diffstack/models/base_models.py new file mode 100644 index 0000000..ce0b671 --- /dev/null +++ b/diffstack/models/base_models.py @@ -0,0 +1,1688 @@ +import numpy as np +import math +import textwrap +from collections import OrderedDict +from typing import Dict, Union, List +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models.resnet import resnet18, resnet50 +from torchvision.models.feature_extraction import create_feature_extractor +from torchvision.ops import RoIAlign +from diffstack.models.Transformer import SimpleTransformer + +from diffstack.utils.tensor_utils import reshape_dimensions, flatten +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.dynamics as dynamics + + +class MLP(nn.Module): + """ + Base class for simple Multi-Layer Perceptrons. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_dim (int): dimension of inputs + output_dim (int): dimension of outputs + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. + normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + super(MLP, self).__init__() + layers = [] + dim = input_dim + if layer_func_kwargs is None: + layer_func_kwargs = dict() + if dropouts is not None: + assert len(dropouts) == len(layer_dims) + for i, l in enumerate(layer_dims): + layers.append(layer_func(dim, l, **layer_func_kwargs)) + if normalization: + layers.append(nn.LayerNorm(l)) + layers.append(activation()) + if dropouts is not None and dropouts[i] > 0.0: + layers.append(nn.Dropout(dropouts[i])) + dim = l + layers.append(layer_func(dim, output_dim)) + if output_activation is not None: + if isinstance(output_activation, nn.Module): + layers.append(output_activation) + else: + layers.append(output_activation()) + self._layer_func = layer_func + self.nets = layers + self._model = nn.Sequential(*layers) + + self._layer_dims = layer_dims + self._input_dim = input_dim + self._output_dim = output_dim + self._dropouts = dropouts + self._act = activation + self._output_act = output_activation + + def output_shape(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + return [self._output_dim] + + def forward(self, inputs): + """ + Forward pass. + """ + return self._model(inputs) + + def __repr__(self): + """Pretty print network.""" + header = str(self.__class__.__name__) + act = None if self._act is None else self._act.__name__ + output_act = None if self._output_act is None else self._output_act.__name__ + + indent = " " * 4 + msg = "input_dim={}\noutput_shape={}\nlayer_dims={}\nlayer_func={}\ndropout={}\nact={}\noutput_act={}".format( + self._input_dim, + self.output_shape(), + self._layer_dims, + self._layer_func.__name__, + self._dropouts, + act, + output_act, + ) + msg = textwrap.indent(msg, indent) + msg = header + "(\n" + msg + "\n)" + return msg + + +class SplitMLP(MLP): + """ + A multi-output MLP network: The model split and reshapes the output layer to the desired output shapes + """ + + def __init__( + self, + input_dim: int, + output_shapes: OrderedDict, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_dim (int): dimension of inputs + output_shapes (dict): named dictionary of output shapes + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. + normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + + assert isinstance(output_shapes, OrderedDict) + output_dim = 0 + for v in output_shapes.values(): + output_dim += np.prod(v) + self._output_shapes = output_shapes + + super(SplitMLP, self).__init__( + input_dim=input_dim, + output_dim=output_dim, + layer_dims=layer_dims, + layer_func=layer_func, + layer_func_kwargs=layer_func_kwargs, + activation=activation, + dropouts=dropouts, + normalization=normalization, + output_activation=output_activation, + ) + + def output_shape(self, input_shape=None): + return self._output_shapes + + def forward(self, inputs): + outs = super(SplitMLP, self).forward(inputs) + out_dict = dict() + ind = 0 + for k, v in self._output_shapes.items(): + v_dim = int(np.prod(v)) + out_dict[k] = reshape_dimensions( + outs[:, ind : ind + v_dim], begin_axis=1, end_axis=2, target_dims=v + ) + ind += v_dim + return out_dict + + +class MIMOMLP(SplitMLP): + """ + A multi-input, multi-output MLP: The model flattens and concatenate the input before feeding into an MLP + """ + + def __init__( + self, + input_shapes: OrderedDict, + output_shapes: OrderedDict, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_shapes (OrderedDict): named dictionary of input shapes + output_shapes (OrderedDict): named dictionary of output shapes + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. + normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + assert isinstance(input_shapes, OrderedDict) + input_dim = 0 + for v in input_shapes.values(): + input_dim += np.prod(v) + + self._input_shapes = input_shapes + + super(MIMOMLP, self).__init__( + input_dim=input_dim, + output_shapes=output_shapes, + layer_dims=layer_dims, + layer_func=layer_func, + layer_func_kwargs=layer_func_kwargs, + activation=activation, + dropouts=dropouts, + normalization=normalization, + output_activation=output_activation, + ) + + def forward(self, inputs): + flat_inputs = [] + for k in self._input_shapes.keys(): + flat_inputs.append(flatten(inputs[k])) + return super(MIMOMLP, self).forward(torch.cat(flat_inputs, dim=1)) + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + # (Mohit): argh... forgot to remove this batchnorm + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + # (Mohit): argh... forgot to remove this batchnorm + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + """ + U-Net forward by concatenating input feature (x1) with mirroring encoding feature maps channel-wise (x2) + Args: + x1 (torch.Tensor): [B, C1, H1, W1] + x2 (torch.Tensor): [B, C2, H2, W2] + + Returns: + output (torch.Tensor): [B, out_channels, H2, W2] + """ + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class IdentityBlock(nn.Module): + def __init__( + self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True + ): + super(IdentityBlock, self).__init__() + self.final_relu = final_relu + self.batchnorm = batchnorm + + filters1, filters2, filters3 = filters + self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() + self.conv2 = nn.Conv2d( + filters1, + filters2, + kernel_size=kernel_size, + dilation=1, + stride=stride, + padding=1, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() + self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += x + if self.final_relu: + out = F.relu(out) + return out + + +class ConvBlock(nn.Module): + def __init__( + self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True + ): + super(ConvBlock, self).__init__() + self.final_relu = final_relu + self.batchnorm = batchnorm + + filters1, filters2, filters3 = filters + self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() + self.conv2 = nn.Conv2d( + filters1, + filters2, + kernel_size=kernel_size, + dilation=1, + stride=stride, + padding=1, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() + self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() + + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, filters3, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity(), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + if self.final_relu: + out = F.relu(out) + return out + + +class UNetDecoder(nn.Module): + """UNet part based on https://github.com/milesial/Pytorch-UNet/tree/master/unet""" + + def __init__( + self, + input_channel, + output_channel, + encoder_channels, + up_factor, + bilinear=True, + batchnorm=True, + ): + super(UNetDecoder, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d( + input_channel, 1024, kernel_size=3, stride=1, padding=1, bias=False + ), + nn.ReLU(True), + ) + + self.up1 = Up(1024 + encoder_channels[-1], 512, bilinear=True) + + self.up2 = Up(512 + encoder_channels[-2], 512 // up_factor, bilinear) + + self.up3 = Up(256 + encoder_channels[-3], 256 // up_factor, bilinear) + + self.layer1 = nn.Sequential( + ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.layer2 = nn.Sequential( + ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.layer3 = nn.Sequential( + ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.conv2 = nn.Sequential(nn.Conv2d(16, output_channel, kernel_size=1)) + + def forward( + self, + feat_to_decode: torch.Tensor, + encoder_feats: List[torch.Tensor], + target_hw: tuple, + ): + assert len(encoder_feats) >= 3 + x = self.conv1(feat_to_decode) + x = self.up1(x, encoder_feats[-1]) + x = self.up2(x, encoder_feats[-2]) + x = self.up3(x, encoder_feats[-3]) + + for layer in [self.layer1, self.layer2, self.layer3, self.conv2]: + x = layer(x) + + x = F.interpolate(x, size=(target_hw[0], target_hw[1]), mode="bilinear") + return x + + +class SpatialSoftmax(nn.Module): + """ + Spatial Softmax Layer. + Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al. + https://rll.berkeley.edu/dsae/dsae.pdf + """ + + def __init__( + self, + input_shape, + num_kp=None, + temperature=1.0, + learnable_temperature=False, + output_variance=False, + noise_std=0.0, + ): + """ + Args: + input_shape (list, tuple): shape of the input feature (C, H, W) + num_kp (int): number of keypoints (None for not use spatialsoftmax) + temperature (float): temperature term for the softmax. + learnable_temperature (bool): whether to learn the temperature + output_variance (bool): treat attention as a distribution, and compute second-order statistics to return + noise_std (float): add random spatial noise to the predicted keypoints + """ + super(SpatialSoftmax, self).__init__() + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape # (C, H, W) + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._num_kp = num_kp + else: + self.nets = None + self._num_kp = self._in_c + self.learnable_temperature = learnable_temperature + self.output_variance = output_variance + self.noise_std = noise_std + + if self.learnable_temperature: + # temperature will be learned + temperature = torch.nn.Parameter( + torch.ones(1) * temperature, requires_grad=True + ) + self.register_parameter("temperature", temperature) + else: + # temperature held constant after initialization + temperature = torch.nn.Parameter( + torch.ones(1) * temperature, requires_grad=False + ) + self.register_buffer("temperature", temperature) + + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) + pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h * self._in_w)).float() + pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h * self._in_w)).float() + self.register_buffer("pos_x", pos_x) + self.register_buffer("pos_y", pos_y) + + self.kps = None + + def __repr__(self): + """Pretty print network.""" + header = format(str(self.__class__.__name__)) + return header + "(num_kp={}, temperature={}, noise={})".format( + self._num_kp, self.temperature.item(), self.noise_std + ) + + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + assert len(input_shape) == 3 + assert input_shape[0] == self._in_c + return [self._num_kp, 2] + + def forward(self, feature): + """ + Forward pass through spatial softmax layer. For each keypoint, a 2D spatial + probability distribution is created using a softmax, where the support is the + pixel locations. This distribution is used to compute the expected value of + the pixel location, which becomes a keypoint of dimension 2. K such keypoints + are created. + Returns: + out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly + keypoint variance of shape [B, K, 2, 2] corresponding to the covariance + under the 2D spatial softmax distribution + """ + assert feature.shape[1] == self._in_c + assert feature.shape[2] == self._in_h + assert feature.shape[3] == self._in_w + if self.nets is not None: + feature = self.nets(feature) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + feature = feature.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(feature / self.temperature, dim=-1) + # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions + expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True) + expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True) + # stack to [B * K, 2] + expected_xy = torch.cat([expected_x, expected_y], 1) + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._num_kp, 2) + + if self.training: + noise = torch.randn_like(feature_keypoints) * self.noise_std + feature_keypoints += noise + + if self.output_variance: + # treat attention as a distribution, and compute second-order statistics to return + expected_xx = torch.sum( + self.pos_x * self.pos_x * attention, dim=1, keepdim=True + ) + expected_yy = torch.sum( + self.pos_y * self.pos_y * attention, dim=1, keepdim=True + ) + expected_xy = torch.sum( + self.pos_x * self.pos_y * attention, dim=1, keepdim=True + ) + var_x = expected_xx - expected_x * expected_x + var_y = expected_yy - expected_y * expected_y + var_xy = expected_xy - expected_x * expected_y + # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix + feature_covar = torch.cat([var_x, var_xy, var_xy, var_y], 1).reshape( + -1, self._num_kp, 2, 2 + ) + feature_keypoints = (feature_keypoints, feature_covar) + + if isinstance(feature_keypoints, tuple): + self.kps = (feature_keypoints[0].detach(), feature_keypoints[1].detach()) + else: + self.kps = feature_keypoints.detach() + return feature_keypoints + + +class RasterizedMapEncoder(nn.Module): + """A basic image-based rasterized map encoder""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + feature_dim: int = None, + use_spatial_softmax=False, + spatial_softmax_kwargs=None, + output_activation=nn.ReLU, + ) -> None: + super().__init__() + self.model_arch = model_arch + self.num_input_channels = input_image_shape[0] + self._feature_dim = feature_dim + if output_activation is None: + self._output_activation = nn.Identity() + else: + self._output_activation = output_activation() + + # configure conv backbone + if model_arch == "resnet18": + self.map_model = resnet18() + out_h = int(math.ceil(input_image_shape[1] / 32.0)) + out_w = int(math.ceil(input_image_shape[2] / 32.0)) + self.conv_out_shape = (512, out_h, out_w) + elif model_arch == "resnet50": + self.map_model = resnet50() + out_h = int(math.ceil(input_image_shape[1] / 32.0)) + out_w = int(math.ceil(input_image_shape[2] / 32.0)) + self.conv_out_shape = (2048, out_h, out_w) + else: + raise NotImplementedError(f"Model arch {model_arch} unknown") + + # configure spatial reduction pooling layer + if use_spatial_softmax: + pooling = SpatialSoftmax( + input_shape=self.conv_out_shape, **spatial_softmax_kwargs + ) + self.pool_out_dim = int(np.prod(pooling.output_shape(self.conv_out_shape))) + else: + pooling = nn.AdaptiveAvgPool2d((1, 1)) + self.pool_out_dim = self.conv_out_shape[0] + + self.map_model.conv1 = nn.Conv2d( + self.num_input_channels, + 64, + kernel_size=(7, 7), + stride=(2, 2), + padding=(3, 3), + bias=False, + ) + self.map_model.avgpool = pooling + if feature_dim is not None: + self.map_model.fc = nn.Linear( + in_features=self.pool_out_dim, out_features=feature_dim + ) + else: + self.map_model.fc = nn.Identity() + + def output_shape(self, input_shape=None): + if self._feature_dim is not None: + return [self._feature_dim] + else: + return [self.pool_out_dim] + + def feature_channels(self): + if self.model_arch in ["resnet18", "resnet34"]: + channels = OrderedDict( + { + "layer1": 64, + "layer2": 128, + "layer3": 256, + "layer4": 512, + } + ) + else: + channels = OrderedDict( + { + "layer1": 256, + "layer2": 512, + "layer3": 1024, + "layer4": 2048, + } + ) + return channels + + def feature_scales(self): + return OrderedDict( + {"layer1": 1 / 4, "layer2": 1 / 8, "layer3": 1 / 16, "layer4": 1 / 32} + ) + + def forward(self, map_inputs) -> torch.Tensor: + feat = self.map_model(map_inputs) + feat = self._output_activation(feat) + return feat + + +class RotatedROIAlign(nn.Module): + def __init__(self, roi_feature_size, roi_scale): + super(RotatedROIAlign, self).__init__() + from diffstack.models.roi_align import ROI_align + + self.roi_align = lambda feat, rois: ROI_align(feat, rois, roi_feature_size[0]) + self.roi_scale = roi_scale + + def forward(self, feats, list_of_rois): + scaled_rois = [] + for rois in list_of_rois: + sroi = rois.clone() + sroi[..., :-1] = rois[..., :-1] * self.roi_scale + scaled_rois.append(sroi) + list_of_feats = self.roi_align(feats, scaled_rois) + return torch.cat(list_of_feats, dim=0) + + +class RasterizeROIEncoder(nn.Module): + """Use RoI Align to crop map feature for each agent""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + agent_feature_dim: int = None, + global_feature_dim: int = None, + roi_feature_size: tuple = (7, 7), + roi_layer_key: str = "layer4", + output_activation=nn.ReLU, + use_rotated_roi=True, + ) -> None: + super(RasterizeROIEncoder, self).__init__() + encoder = RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=global_feature_dim, + ) + feat_nodes = { + "map_model.layer1": "layer1", + "map_model.layer2": "layer2", + "map_model.layer3": "layer3", + "map_model.layer4": "layer4", + "map_model.fc": "final", + } + self.encoder_heads = create_feature_extractor(encoder, feat_nodes) + + self.roi_layer_key = roi_layer_key + roi_scale = encoder.feature_scales()[roi_layer_key] + roi_channel = encoder.feature_channels()[roi_layer_key] + if use_rotated_roi: + self.roi_align = RotatedROIAlign(roi_feature_size, roi_scale=roi_scale) + else: + self.roi_align = RoIAlign( + output_size=roi_feature_size, + spatial_scale=roi_scale, + sampling_ratio=-1, + aligned=True, + ) + + self.activation = output_activation() + if agent_feature_dim is not None: + self.agent_net = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), # [B, C, 1, 1] + nn.Flatten(start_dim=1), + nn.Linear(roi_channel, agent_feature_dim), + self.activation, + ) + else: + self.agent_net = nn.Identity() + + def forward(self, map_inputs: torch.Tensor, rois: torch.Tensor): + """ + + Args: + map_inputs (torch.Tensor): [B, C, H, W] + rois (torch.Tensor): [num_boxes, 5] + + Returns: + agent_feats (torch.Tensor): [num_boxes, ...] + roi_feats (torch.Tensor): [num_boxes, ...] + global_feats (torch.Tensor): [B, ....] + """ + feats = self.encoder_heads(map_inputs) + pre_roi_feats = feats[self.roi_layer_key] + roi_feats = self.roi_align(pre_roi_feats, rois) # [num_boxes, C, H, W] + agent_feats = self.agent_net(roi_feats) + global_feats = feats["final"] + + return agent_feats, roi_feats, self.activation(global_feats), feats + + +class RasterizedMapKeyPointNet(RasterizedMapEncoder): + """Predict a list of keypoints [x, y] given a rasterized map""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + spatial_softmax_kwargs=None, + ) -> None: + super().__init__( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=None, + use_spatial_softmax=True, + spatial_softmax_kwargs=spatial_softmax_kwargs, + ) + + def forward(self, map_inputs) -> torch.Tensor: + outputs = super(RasterizedMapKeyPointNet, self).forward(map_inputs) + # reshape back to kp format + return outputs.reshape(outputs.shape[0], -1, 2) + + +class RasterizedMapUNet(nn.Module): + """Predict a spatial map same size as the input rasterized map""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + output_channel=4, + use_spatial_softmax=False, + spatial_softmax_kwargs=None, + ) -> None: + super(RasterizedMapUNet, self).__init__() + encoder = RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + use_spatial_softmax=use_spatial_softmax, + spatial_softmax_kwargs=spatial_softmax_kwargs, + ) + self.input_image_shape = input_image_shape + # build graph for extracting intermediate features + feat_nodes = { + "map_model.layer1": "layer1", + "map_model.layer2": "layer2", + "map_model.layer3": "layer3", + "map_model.layer4": "layer4", + } + self.encoder_heads = create_feature_extractor(encoder, feat_nodes) + encoder_channels = list(encoder.feature_channels().values()) + self.decoder = UNetDecoder( + input_channel=encoder_channels[-1], + encoder_channels=encoder_channels[:-1], + output_channel=output_channel, + up_factor=2, + ) + + def forward(self, map_inputs, encoder_feats=None): + if encoder_feats is None: + encoder_feats = self.encoder_heads(map_inputs) + encoder_feats = [ + encoder_feats[k] for k in ["layer1", "layer2", "layer3", "layer4"] + ] + return self.decoder.forward( + feat_to_decode=encoder_feats[-1], + encoder_feats=encoder_feats[:-1], + target_hw=self.input_image_shape[1:], + ) + + +class RNNTrajectoryEncoder(nn.Module): + def __init__( + self, + trajectory_dim, + rnn_hidden_size, + feature_dim=None, + mlp_layer_dims: tuple = (), + mode="last", + ): + super(RNNTrajectoryEncoder, self).__init__() + self.mode = mode + self.lstm = nn.LSTM( + trajectory_dim, hidden_size=rnn_hidden_size, batch_first=True + ) + if feature_dim is not None: + self.mlp = MLP( + input_dim=rnn_hidden_size, + output_dim=feature_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + self._feature_dim = feature_dim + else: + self.mlp = nn.Identity() + self._feature_dim = rnn_hidden_size + + def output_shape(self, input_shape=None): + return [self._feature_dim] + + def forward(self, input_trajectory): + if self.mode == "last": + traj_feat = self.lstm(input_trajectory)[0][:, -1, :] + elif self.mode == "all": + traj_feat = self.lstm(input_trajectory)[0] + traj_feat = self.mlp(traj_feat) + return traj_feat + + +class RNNFeatureRoller(nn.Module): + def __init__(self, trajectory_dim, feature_dim): + super(RNNFeatureRoller, self).__init__() + self.gru = nn.GRU(trajectory_dim, hidden_size=feature_dim, batch_first=True) + self._feature_dim = feature_dim + + def output_shape(self, input_shape=None): + return [self._feature_dim] + + def forward(self, feature, input_trajectory): + _, hn = self.gru(input_trajectory, feature.unsqueeze(0)) + return hn[0] + feature + + +class PosteriorEncoder(nn.Module): + """Posterior Encoder (x, x_c -> q) for CVAE""" + + def __init__( + self, + condition_dim: int, + trajectory_shape: tuple, # [T, D] + output_shapes: OrderedDict, + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + ) -> None: + super(PosteriorEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_dim=trajectory_shape[-1], rnn_hidden_size=rnn_hidden_size + ) + self.mlp = SplitMLP( + input_dim=(rnn_hidden_size + condition_dim), + output_shapes=output_shapes, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs, condition_features) -> Dict[str, torch.Tensor]: + traj_feat = self.traj_encoder(inputs["trajectories"]) + feat = torch.cat((traj_feat, condition_features), dim=-1) + return self.mlp(feat) + + +class ScenePosteriorEncoder(nn.Module): + """Scene Posterior Encoder (x, x_c -> q) for CVAE""" + + def __init__( + self, + condition_dim: int, + trajectory_shape: tuple, # [T, D] + output_shapes: OrderedDict, + aggregate_func="max", + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + ) -> None: + super(ScenePosteriorEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + self.transformer = SimpleTransformer(src_dim=rnn_hidden_size + condition_dim) + self.aggregate_func = aggregate_func + + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_dim=trajectory_shape[-1], rnn_hidden_size=rnn_hidden_size + ) + self.mlp = SplitMLP( + input_dim=(rnn_hidden_size + condition_dim), + output_shapes=output_shapes, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs, condition_features, mask, pos) -> Dict[str, torch.Tensor]: + bs, Na, T = inputs["trajectories"].shape[:3] + + traj_feat = self.traj_encoder( + TensorUtils.join_dimensions(inputs["trajectories"], 0, 2) + ) + feat = torch.cat( + (traj_feat, TensorUtils.join_dimensions(condition_features, 0, 2)), dim=-1 + ).reshape(bs, Na, -1) + + feat = self.transformer(feat, mask, pos) + feat + if self.aggregate_func == "max": + feat = feat.max(1)[0] + elif self.aggregate_func == "mean": + feat = feat.mean(1) + + return self.mlp(feat) + + +class ConditionEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + map_encoder: nn.Module, + trajectory_shape: tuple, # [T, D] + condition_dim: int, + mlp_layer_dims: tuple = (128, 128), + goal_encoder=None, + agent_traj_encoder=None, + ) -> None: + super(ConditionEncoder, self).__init__() + self.map_encoder = map_encoder + self.trajectory_shape = trajectory_shape + self.goal_encoder = goal_encoder + self.agent_traj_encoder = agent_traj_encoder + + visual_feature_size = self.map_encoder.output_shape()[0] + self.mlp = MLP( + input_dim=visual_feature_size, + output_dim=condition_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, condition_inputs): + map_feat = self.map_encoder(condition_inputs["image"]) + c_feat = self.mlp(map_feat) + if self.goal_encoder is not None: + goal_feat = self.goal_encoder(condition_inputs["goal"]) + c_feat = torch.cat([c_feat, goal_feat], dim=-1) + if self.agent_traj_encoder is not None and "agent_traj" in condition_inputs: + agent_traj_feat = self.agent_traj_encoder(condition_inputs["agent_traj"]) + c_feat = torch.cat([c_feat, agent_traj_feat], dim=-1) + return c_feat + + +class ECEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + map_encoder, + trajectory_shape: tuple, # [T, D] + EC_dim: int, + condition_dim: int, + mlp_layer_dims: tuple = (128, 128), + goal_encoder=None, + rnn_hidden_size: int = 100, + ) -> None: + super(ECEncoder, self).__init__() + if isinstance(map_encoder, nn.Module): + self.map_encoder = map_encoder + visual_feature_size = self.map_encoder.output_shape()[0] + elif isinstance(map_encoder, int): + visual_feature_size = map_encoder + self.map_encoder = None + self.trajectory_shape = trajectory_shape + self.goal_encoder = goal_encoder + self.EC_dim = EC_dim + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_shape[-1], rnn_hidden_size, feature_dim=EC_dim + ) + goal_dim = 0 if goal_encoder is None else goal_encoder.output_shape()[0] + self.mlp = MLP( + input_dim=visual_feature_size + goal_dim + EC_dim, + output_dim=condition_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, condition_inputs): + if self.map_encoder is None: + c_feat = condition_inputs["map_feature"] + else: + c_feat = self.map_encoder(condition_inputs["image"]) + if self.goal_encoder is not None: + goal_feat = self.goal_encoder(condition_inputs["goal"]) + c_feat = torch.cat([c_feat, goal_feat], dim=-1) + if ( + "cond_traj" in condition_inputs + and condition_inputs["cond_traj"] is not None + ): + if condition_inputs["cond_traj"].ndim == 3: + EC_feat = self.traj_encoder(condition_inputs["cond_traj"]).unsqueeze(1) + elif condition_inputs["cond_traj"].ndim == 4: + bs, M, T, D = condition_inputs["cond_traj"].shape + EC_feat = self.traj_encoder( + condition_inputs["cond_traj"].reshape(-1, T, D) + ).reshape(bs, M, -1) + + EC_feat = EC_feat * (condition_inputs["cond_traj"][..., 0] != 0).any( + -1 + ).unsqueeze(-1) + EC_feat = torch.cat( + (EC_feat, torch.zeros([bs, 1, self.EC_dim]).to(EC_feat.device)), 1 + ) + else: + bs = c_feat.shape[0] + EC_feat = torch.zeros([bs, 1, self.EC_dim]).to(c_feat.device) + if c_feat.ndim == 2: + c_feat = c_feat.unsqueeze(1).repeat(1, EC_feat.shape[1], 1) + else: + assert c_feat.ndim == 3 and c_feat.shape[1] == EC_feat.shape[1] + c_feat = torch.cat((c_feat, EC_feat), -1) + c_feat = self.mlp(c_feat) + + return c_feat + + +class AgentTrajEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + trajectory_shape: tuple, # [T, D] + feature_dim: int, + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + use_transformer=True, + ) -> None: + super(AgentTrajEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + self.feature_dim = feature_dim + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_shape[-1], + rnn_hidden_size, + feature_dim=feature_dim, + mlp_layer_dims=mlp_layer_dims, + ) + if use_transformer: + self.transformer = SimpleTransformer(src_dim=feature_dim) + else: + self.transformer = None + + def forward(self, agent_trajs): + bs, M, T, D = agent_trajs.shape + feat = self.traj_encoder(agent_trajs.reshape(-1, T, D)).reshape(bs, M, -1) + if self.transformer is not None: + agent_pos = agent_trajs[:, :, 0, :2] + avails = (agent_pos != 0).any(-1) + feat = self.transformer(feat, avails, agent_pos) + + return torch.max(feat, 1)[0] + + +class PosteriorNet(nn.Module): + def __init__( + self, + input_shapes: OrderedDict, + condition_dim: int, + param_shapes: OrderedDict, + mlp_layer_dims: tuple = (), + ): + super(PosteriorNet, self).__init__() + all_shapes = deepcopy(input_shapes) + all_shapes["condition_features"] = (condition_dim,) + self.mlp = MIMOMLP( + input_shapes=all_shapes, + output_shapes=param_shapes, + layer_dims=mlp_layer_dims, + output_activation=None, + ) + + def forward(self, inputs: dict, condition_features: torch.Tensor): + all_inputs = dict(inputs) + all_inputs["condition_features"] = condition_features + return self.mlp(all_inputs) + + +class ConditionNet(nn.Module): + def __init__( + self, + condition_input_shapes: OrderedDict, + condition_dim: int, + mlp_layer_dims: tuple = (), + ): + super(ConditionNet, self).__init__() + self.mlp = MIMOMLP( + input_shapes=condition_input_shapes, + output_shapes=OrderedDict(feat=(condition_dim,)), + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs: dict): + return self.mlp(inputs)["feat"] + + +class ConditionDecoder(nn.Module): + """Decoding (z, c) -> x' using a flat MLP""" + + def __init__(self, decoder_model: nn.Module): + super(ConditionDecoder, self).__init__() + self.decoder_model = decoder_model + + def forward(self, latents, condition_features, **decoder_kwargs): + return self.decoder_model( + torch.cat((latents, condition_features), dim=-1), **decoder_kwargs + ) + + +class TrajectoryDecoder(nn.Module): + def __init__( + self, + feature_dim: int, + state_dim: int = 3, + num_steps: int = None, + dynamics_type: Union[str, dynamics.DynType] = None, + dynamics_kwargs: dict = None, + step_time: float = None, + network_kwargs: dict = None, + Gaussian_var=False, + ): + """ + A class that predict future trajectories based on input features + Args: + feature_dim (int): dimension of the input feature + state_dim (int): dimension of the output trajectory at each step + num_steps (int): (optional) number of future state to predict + dynamics_type (str, dynamics.DynType): (optional) if specified, the network predicts action + for the dynamics model instead of future states. The actions are then used to predict + the future trajectories. + step_time (float): time between steps. required for using dynamics models + network_kwargs (dict): keyword args for the decoder networks + Gaussian_var (bool): whether output the variance of the predicted trajectory + """ + super(TrajectoryDecoder, self).__init__() + self.feature_dim = feature_dim + self.state_dim = state_dim + self.num_steps = num_steps + self.step_time = step_time + self._network_kwargs = network_kwargs + self._dynamics_type = dynamics_type + self._dynamics_kwargs = dynamics_kwargs + self.Gaussian_var = Gaussian_var + self._create_dynamics() + self._create_networks() + + def _create_dynamics(self): + if self._dynamics_type in ["Unicycle", dynamics.DynType.UNICYCLE]: + self.dyn = dynamics.Unicycle( + self.step_time, + max_steer=self._dynamics_kwargs["max_steer"], + max_yawvel=self._dynamics_kwargs["max_yawvel"], + acce_bound=self._dynamics_kwargs["acce_bound"], + ) + elif self._dynamics_type in ["Bicycle", dynamics.DynType.BICYCLE]: + self.dyn = dynamics.Bicycle( + acc_bound=self._dynamics_kwargs["acce_bound"], + ddh_bound=self._dynamics_kwargs["ddh_bound"], + max_hdot=self._dynamics_kwargs["max_yawvel"], + max_speed=self._dynamics_kwargs["max_speed"], + ) + else: + self.dyn = None + + def _create_networks(self): + raise NotImplementedError + + def _forward_networks(self, inputs, current_states=None, num_steps=None): + raise NotImplementedError + + def _forward_dynamics(self, current_states, actions): + assert self.dyn is not None + assert current_states.shape[-1] == self.dyn.xdim + assert actions.shape[-1] == self.dyn.udim + assert isinstance(self.step_time, float) and self.step_time > 0 + x = self.dyn.forward_dynamics( + x0=current_states, + u=actions, + bound=False, + ) + pos = self.dyn.state2pos(x) + yaw = self.dyn.state2yaw(x) + traj = torch.cat((pos, yaw), dim=-1) + return traj, x + + def forward(self, inputs, current_states=None, num_steps=None): + preds = self._forward_networks( + inputs, current_states=current_states, num_steps=num_steps + ) + if self.dyn is not None: + preds["controls"] = preds["trajectories"] + preds["trajectories"], x = self._forward_dynamics( + current_states=current_states, actions=preds["trajectories"] + ) + preds["terminal_state"] = x[..., -1, :] + return preds + + +class MLPTrajectoryDecoder(TrajectoryDecoder): + def _create_networks(self): + net_kwargs = ( + dict() if self._network_kwargs is None else dict(self._network_kwargs) + ) + if self._network_kwargs is None: + net_kwargs = dict() + + assert isinstance(self.num_steps, int) + if self.dyn is None: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.state_dim)) + else: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.dyn.udim)) + if self.Gaussian_var: + pred_shapes["logvar"] = (self.num_steps, self.state_dim) + + state_as_input = net_kwargs.pop("state_as_input") + if self.dyn is not None: + assert state_as_input # TODO: deprecated, set default to True and remove from configs + + if state_as_input and self.dyn is not None: + feature_dim = self.feature_dim + self.dyn.xdim + else: + feature_dim = self.feature_dim + + self.mlp = SplitMLP( + input_dim=feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + + def _forward_networks(self, inputs, current_states=None, num_steps=None): + if self._network_kwargs["state_as_input"] and self.dyn is not None: + inputs = torch.cat((inputs, current_states), dim=-1) + + if inputs.ndim == 2: + # [B, D] + preds = self.mlp(inputs) + elif inputs.ndim == 3: + # [B, A, D] + preds = TensorUtils.time_distributed(inputs, self.mlp) + else: + raise ValueError( + "Expecting inputs to have ndim == 2 or 3, got {}".format(inputs.ndim) + ) + return preds + + +class MLPECTrajectoryDecoder(TrajectoryDecoder): + def __init__( + self, + feature_dim: int, + state_dim: int = 3, + num_steps: int = None, + dynamics_type: Union[str, dynamics.DynType] = None, + dynamics_kwargs: dict = None, + step_time: float = None, + EC_feature_dim=64, + network_kwargs: dict = None, + Gaussian_var=False, + ): + """ + A class that predict future trajectories based on input features + Args: + feature_dim (int): dimension of the input feature + state_dim (int): dimension of the output trajectory at each step + num_steps (int): (optional) number of future state to predict + dynamics_type (str, dynamics.DynType): (optional) if specified, the network predicts action + for the dynamics model instead of future states. The actions are then used to predict + the future trajectories. + step_time (float): time between steps. required for using dynamics models + network_kwargs (dict): keyword args for the decoder networks + Gaussian_var (bool): whether output the variance of the predicted trajectory + """ + super(TrajectoryDecoder, self).__init__() + self.feature_dim = feature_dim + self.state_dim = state_dim + self.num_steps = num_steps + self.step_time = step_time + self.EC_feature_dim = EC_feature_dim + self._network_kwargs = network_kwargs + self._dynamics_type = dynamics_type + self._dynamics_kwargs = dynamics_kwargs + self.Gaussian_var = Gaussian_var + self._create_dynamics() + self._create_networks() + + def _create_networks(self): + net_kwargs = ( + dict() if self._network_kwargs is None else dict(self._network_kwargs) + ) + if self._network_kwargs is None: + net_kwargs = dict() + + assert isinstance(self.num_steps, int) + if self.dyn is None: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.state_dim)) + else: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.dyn.udim)) + if self.Gaussian_var: + pred_shapes["logvar"] = (self.num_steps, self.state_dim) + + state_as_input = net_kwargs.pop("state_as_input") + if self.dyn is not None: + assert state_as_input # TODO: deprecated, set default to True and remove from configs + + if state_as_input and self.dyn is not None: + feature_dim = self.feature_dim + self.dyn.xdim + else: + feature_dim = self.feature_dim + + self.mlp = SplitMLP( + input_dim=feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + self.offsetmlp = SplitMLP( + input_dim=feature_dim + self.EC_feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + + def _forward_networks( + self, inputs, EC_feat=None, current_states=None, num_steps=None + ): + if self._network_kwargs["state_as_input"] and self.dyn is not None: + inputs = torch.cat((inputs, current_states), dim=-1) + if inputs.ndim == 2: + # [B, D] + + preds = self.mlp(inputs) + if EC_feat is not None: + bs, M = EC_feat.shape[:2] + + # EC_feat = self.traj_encoder(cond_traj.reshape(-1,T,D)).reshape(bs,M,-1) + inputs_tile = inputs.unsqueeze(1).tile(1, M, 1) + EC_feat = torch.cat((inputs_tile, EC_feat), dim=-1) + EC_preds = TensorUtils.time_distributed(EC_feat, self.offsetmlp) + EC_preds["trajectories"] = EC_preds["trajectories"] + preds[ + "trajectories" + ].unsqueeze(1) + else: + EC_preds = None + + elif inputs.ndim == 3: + # [B, A, D] + preds = TensorUtils.time_distributed(inputs, self.mlp) + if EC_feat is not None: + assert EC_feat.ndim == 4 + bs, A, M = EC_feat.shape[:3] + # EC_feat = self.traj_encoder(cond_traj.reshape(-1,T,D)).reshape(bs,M,-1) + inputs_tile = inputs.tile(1, M, 1) + EC_feat = torch.cat((inputs_tile, EC_feat), dim=-1) + EC_preds = TensorUtils.time_distributed(EC_feat, self.offsetmlp) + EC_preds = reshape_dimensions(EC_preds, 1, 2, (A, M)) + EC_preds["trajectories"] = EC_preds["trajectories"] + preds[ + "trajectories" + ].unsqueeze(2) + else: + EC_preds = None + else: + raise ValueError( + "Expecting inputs to have ndim == 2 or 3, got {}".format(inputs.ndim) + ) + return preds, EC_preds + + def _forward_dynamics(self, current_states, actions, **kwargs): + assert self.dyn is not None + assert current_states.shape[-1] == self.dyn.xdim + assert actions.shape[-1] == self.dyn.udim + assert isinstance(self.step_time, float) and self.step_time > 0 + + x = self.dyn.forward_dynamics( + x0=current_states, + u=actions, + bound=False, + ) + pos = self.dyn.state2pos(x) + yaw = self.dyn.state2yaw(x) + traj = torch.cat((pos, yaw), dim=-1) + return traj + + def forward(self, inputs, current_states=None, EC_feat=None, num_steps=None): + preds, EC_preds = self._forward_networks( + inputs, EC_feat, current_states=current_states, num_steps=num_steps + ) + if self.dyn is not None: + preds["controls"] = preds["trajectories"] + if EC_preds is None: + preds["trajectories"] = self._forward_dynamics( + current_states=current_states, actions=preds["trajectories"] + ) + else: + total_actions = torch.cat( + (preds["trajectories"].unsqueeze(1), EC_preds["trajectories"]), 1 + ) + bs, A, T, D = total_actions.shape + current_states_tiled = current_states.unsqueeze(1).repeat( + 1, total_actions.size(1), 1 + ) + total_trajectories = self._forward_dynamics( + current_states=current_states_tiled.reshape(bs * A, -1), + actions=total_actions.reshape(bs * A, T, D), + ).reshape(*total_actions.shape[:-1], -1) + preds["trajectories"] = total_trajectories[:, 0] + preds["EC_trajectories"] = total_trajectories[:, 1:] + else: + preds["EC_trajectories"] = EC_preds["trajectories"] + + return preds + + +class clique_gibbs_distr(nn.Module): + def __init__( + self, + state_enc_dim, + edge_encoding_dim, + z_dim, + edge_types, + node_types, + node_hidden_dim=[64, 64], + edge_hidden_dim=[64, 64], + ): + super(clique_gibbs_distr, self).__init__() + self.edge_encoding_dim = edge_encoding_dim + self.z_dim = z_dim + self.state_enc_dim = state_enc_dim + self.node_types = node_types + self.edge_types = edge_types + self.et_name = dict() + for et in self.edge_types: + self.et_name[et] = et[0].name + "->" + et[1].name + + self.node_factor = nn.ModuleDict() + self.edge_factor = nn.ModuleDict() + for node_type in self.node_types: + self.node_factor[node_type.name] = MLP( + state_enc_dim[node_type], z_dim, node_hidden_dim + ) + + for edge_type in self.edge_types: + self.edge_factor[self.et_name[edge_type]] = MLP( + edge_encoding_dim[edge_type], z_dim**2, edge_hidden_dim + ) + + def forward( + self, + node_types, + node_enc, + edge_enc, + clique_index, + clique_node_index, + clique_order_index, + clique_is_robot=None, + ): + device = node_enc.device + bs, Na = node_types.shape[:2] + Nc, max_cs = clique_node_index.shape[1:] + + # clique_edge_mask = (clique_index.unsqueeze(1)==clique_index.unsqueeze(2))*(clique_index>=0).unsqueeze(1) + # clique_edge_mask = clique_edge_mask*(~torch.eye(Na,dtype=torch.bool,device=device).unsqueeze(0)) + node_factor = torch.zeros([bs, Na, self.z_dim]).to(device) + for nt in self.node_types: + node_factor += self.node_factor[nt.name](node_enc) * ( + node_types == nt.value + ).unsqueeze(-1) + + edge_factor = torch.zeros([bs, Na, Na, self.z_dim**2]).to(device) + for et in self.edge_types: + et_flag = (node_types == et[0].value).unsqueeze(1) * ( + node_types == et[1].value + ).unsqueeze(2) + edge_factor += self.edge_factor[self.et_name[et]]( + edge_enc[..., -1, :] + ) * et_flag.unsqueeze(-1) + edge_factor = edge_factor.reshape(bs, Na, Na, self.z_dim, self.z_dim) + + node_factor_extended = torch.cat( + (node_factor, torch.zeros_like(node_factor[:, 0:1])), 1 + ) + edge_factor_extended = torch.cat( + ( + torch.cat( + ( + edge_factor, + torch.zeros([bs, 1, Na, self.z_dim, self.z_dim], device=device), + ), + 1, + ), + torch.zeros([bs, Na + 1, 1, self.z_dim, self.z_dim], device=device), + ), + 2, + ) + edge_factor_extended_flat = edge_factor_extended.reshape( + bs, -1, self.z_dim, self.z_dim + ) + + logpi = torch.zeros([bs, Nc, *(max_cs * [self.z_dim])]).to(device) + z = torch.stack(torch.meshgrid(*(max_cs * [torch.arange(self.z_dim)])), -1).to( + device + ) + + # add the node factor to the Gibbs distribution + idx = ( + clique_node_index.reshape(bs, -1) + .unsqueeze(-1) + .repeat_interleave(self.z_dim, -1) + ) + idx = idx.masked_fill(idx == -1, Na) + node_factor_clique = torch.gather(node_factor_extended, 1, idx).reshape( + bs, Nc, max_cs, self.z_dim + ) + for i in range(max_cs): + logpi += node_factor_clique[..., i, :].reshape( + [bs, Nc, *(i * [1]), self.z_dim, *((max_cs - i - 1) * [1])] + ) + + # add the edge factor to the Gibbs distribution + for i in range(max_cs - 1): + for j in range(i + 1, max_cs): + idx = clique_node_index[:, :, [i, j]] + idx = idx.masked_fill(idx == -1, Na) + idx_flat = idx[..., 0] * (Na + 1) + idx[..., 1] + idx_flat = ( + idx_flat.reshape(bs, -1, 1, 1) + .repeat_interleave(self.z_dim, 2) + .repeat_interleave(self.z_dim, 3) + ) + edge_factor_ij = torch.gather(edge_factor_extended_flat, 1, idx_flat) + logpi += edge_factor_ij.reshape( + bs, + Nc, + *(i * [1]), + self.z_dim, + *((j - i - 1) * [1]), + self.z_dim, + *((max_cs - j - 1) * [1]), + ) + + return logpi, z + + +class AdditiveAttention(nn.Module): + # Implementing the attention module of Bahdanau et al. 2015 where + # score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1)) + def __init__( + self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None + ): + super(AdditiveAttention, self).__init__() + self.encoder_hidden_state_dim = encoder_hidden_state_dim + self.decoder_hidden_state_dim = decoder_hidden_state_dim + if internal_dim is None: + internal_dim = int( + (encoder_hidden_state_dim + decoder_hidden_state_dim) / 2 + ) + self.internal_dim = internal_dim + self.w1 = nn.Linear(encoder_hidden_state_dim, internal_dim, bias=False) + self.w2 = nn.Linear(decoder_hidden_state_dim, internal_dim, bias=False) + self.v = nn.Linear(internal_dim, 1, bias=False) + + def score(self, encoder_state, decoder_state): + # encoder_state is of shape (batch, enc_dim) + # decoder_state is of shape (batch, dec_dim) + # return value should be of shape (batch, 1) + return self.v(torch.tanh(self.w1(encoder_state) + self.w2(decoder_state))) + + def forward(self, encoder_states, decoder_state): + # encoder_states is of shape (batch, num_enc_states, enc_dim) + # decoder_state is of shape (batch, dec_dim) + score_vec = torch.cat( + [ + self.score(encoder_states[:, i], decoder_state) + for i in range(encoder_states.shape[1]) + ], + dim=1, + ) + # score_vec is of shape (batch, num_enc_states) + + attention_probs = torch.unsqueeze(F.softmax(score_vec, dim=1), dim=2) + # attention_probs is of shape (batch, num_enc_states, 1) + + final_context_vec = torch.sum(attention_probs * encoder_states, dim=1) + # final_context_vec is of shape (batch, enc_dim) + + return final_context_vec, attention_probs + + +if __name__ == "__main__": + # model = RasterizedMapUNet(model_arch="resnet18", input_image_shape=(15, 224, 224), output_channel=4) + t = torch.randn(2, 15, 224, 224) + centers = torch.randint(low=10, high=214, size=(10, 2)).float() + yaws = torch.zeros(size=(10, 1)) + extents = torch.ones(10, 2) * 2 + from diffstack.utils.geometry_utils import ( + get_box_world_coords, + get_square_upright_box, + ) + + boxes = get_box_world_coords(pos=centers, yaw=yaws, extent=extents) + boxes_aligned = torch.flatten(boxes[:, [0, 2], :], start_dim=1) + box_aligned = get_square_upright_box(centers, 2) + from IPython import embed + + embed() + boxes_indices = torch.zeros(10, 1) + boxes_indices[5:] = 1 + boxes_indexed = torch.cat((boxes_indices, boxes_aligned), dim=1) + + model = RasterizeROIEncoder(model_arch="resnet50", input_image_shape=(15, 224, 224)) + output = model(t, boxes_indexed) + for f in output: + print(f.shape) diff --git a/diffstack/models/cnn_roi_encoder.py b/diffstack/models/cnn_roi_encoder.py new file mode 100644 index 0000000..ea197f9 --- /dev/null +++ b/diffstack/models/cnn_roi_encoder.py @@ -0,0 +1,558 @@ +from logging import raiseExceptions +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffstack.utils.geometry_utils import batch_nd_transform_points + + +class CNNROIMapEncoder(nn.Module): + def __init__( + self, + map_channels, + hidden_channels, + ROI_outdim, + output_size, + kernel_size, + strides, + input_size, + ): + """ + multi-layer CNN with ROI align for the output + Args: + map_channels (int): map channel numbers + ROI (list): list of ROIs + ROI_outdim (int): ROI points along each dim total interpolating points: ROI_outdim x ROI_outdim + output_size (int): output feature size + kernel_size (list): CNN kernel size for each layer + strides (list): CNN strides for each layer + input_size (tuple): map size + + """ + super(CNNROIMapEncoder, self).__init__() + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + self.num_channel_last = hidden_channels[-1] + self.ROI_outdim = ROI_outdim + x_dummy = torch.ones([map_channels, *input_size]).unsqueeze(0) * torch.tensor( + float("nan") + ) + + for i, hidden_size in enumerate(hidden_channels): + self.convs.append( + nn.Conv2d( + map_channels if i == 0 else hidden_channels[i - 1], + hidden_channels[i], + kernel_size[i], + stride=strides[i], + padding=int((kernel_size[i] - 1) / 2), + ) + ) + self.bns.append(nn.BatchNorm2d(hidden_size)) + x_dummy = self.convs[i](x_dummy) + + "fully connected layer after ROI align" + self.fc = nn.Linear( + ROI_outdim * ROI_outdim * self.num_channel_last, output_size + ) + + def forward(self, x, ROI): + """ + + Args: + x (torch.tensor): image + ROI (list): ROIs + + Returns: + out (list): ROI align result for each ROI + """ + + for conv, bn in zip(self.convs, self.bns): + x0 = x + x = F.leaky_relu(conv(x), 0.2) + x = bn(x) + x = ROI_align(x, ROI, self.ROI_outdim) + out = [None] * len(x) + for i in range(len(x)): + out[i] = self.fc(x[i].flatten(start_dim=-3)) + + return out + + +# def bilinear_interpolate(img, x, y, floattype=torch.float): +# """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img +# Args: +# img (torch.Tensor): Tensor of size cxwxh. Usually one channel of feature layer +# x (torch.Tensor): Float dtype, x axis location for sampling +# y (torch.Tensor): Float dtype, y axis location for sampling +# batched version + +# Returns: +# torch.Tensor: interpolated value +# """ +# bs = img.size(0) +# x0 = torch.floor(x).type(torch.cuda.LongTensor) +# x1 = x0 + 1 + +# y0 = torch.floor(y).type(torch.cuda.LongTensor) +# y1 = y0 + 1 + +# x0 = torch.clamp(x0, 0, img.shape[-2] - 1) +# x1 = torch.clamp(x1, 0, img.shape[-2] - 1) +# y0 = torch.clamp(y0, 0, img.shape[-1] - 1) +# y1 = torch.clamp(y1, 0, img.shape[-1] - 1) + +# Ia = [None] * bs +# Ib = [None] * bs +# Ic = [None] * bs +# Id = [None] * bs +# for i in range(bs): +# Ia[i] = img[i, ..., y0[i], x0[i]] +# Ib[i] = img[i, ..., y1[i], x0[i]] +# Ic[i] = img[i, ..., y0[i], x1[i]] +# Id[i] = img[i, ..., y1[i], x1[i]] + +# Ia = torch.stack(Ia, dim=0) +# Ib = torch.stack(Ib, dim=0) +# Ic = torch.stack(Ic, dim=0) +# Id = torch.stack(Id, dim=0) + +# step = (x1.type(floattype) - x0.type(floattype)) * ( +# y1.type(floattype) - y0.type(floattype) +# ) +# step = torch.clamp(step, 1e-3, 2) +# norm_const = 1 / step + +# wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const +# wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const +# wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const +# wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const +# return ( +# Ia * wa.unsqueeze(1) +# + Ib * wb.unsqueeze(1) +# + Ic * wc.unsqueeze(1) +# + Id * wd.unsqueeze(1) +# ) +def bilinear_interpolate(img, x, y, floattype=torch.float, flip_y=False): + """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img + Args: + img (torch.Tensor): Tensor of size cxwxh. Usually one channel of feature layer + x (torch.Tensor): Float dtype, x axis location for sampling + y (torch.Tensor): Float dtype, y axis location for sampling + + Returns: + torch.Tensor: interpolated value + """ + if flip_y: + y = img.shape[-2] - 1-y + if img.device.type == "cuda": + x0 = torch.floor(x).type(torch.cuda.LongTensor) + y0 = torch.floor(y).type(torch.cuda.LongTensor) + elif img.device.type == "cpu": + x0 = torch.floor(x).type(torch.LongTensor) + y0 = torch.floor(y).type(torch.LongTensor) + else: + raise ValueError("device not recognized") + x1 = x0 + 1 + y1 = y0 + 1 + + x0 = torch.clamp(x0, 0, img.shape[-1] - 1) + x1 = torch.clamp(x1, 0, img.shape[-1] - 1) + y0 = torch.clamp(y0, 0, img.shape[-2] - 1) + y1 = torch.clamp(y1, 0, img.shape[-2] - 1) + + Ia = img[..., y0, x0] + Ib = img[..., y1, x0] + Ic = img[..., y0, x1] + Id = img[..., y1, x1] + + step = (x1.type(floattype) - x0.type(floattype)) * ( + y1.type(floattype) - y0.type(floattype) + ) + step = torch.clamp(step, 1e-3, 2) + norm_const = 1 / step + + wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const + wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const + wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const + wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const + return ( + Ia * wa.unsqueeze(0) + + Ib * wb.unsqueeze(0) + + Ic * wc.unsqueeze(0) + + Id * wd.unsqueeze(0) + ) + + +def ROI_align(features, ROI, outdim): + """Given feature layers and proposals return bilinear interpolated + points in feature layer + + Args: + features (torch.Tensor): Tensor of shape channels x width x height + proposal (list of torch.Tensor): x0,y0,W1,W2,H1,H2,psi + """ + + bs, num_channels, h, w = features.shape + + xg = ( + torch.cat( + ( + torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, + torch.zeros([outdim, 1]), + ), + dim=-1, + ) + / outdim + ) + yg = ( + torch.cat( + ( + torch.zeros([outdim, 1]), + torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, + ), + dim=-1, + ) + / outdim + ) + gg = xg.view(1, -1, 2) + yg.view(-1, 1, 2) + gg = gg.to(features.device) + res = list() + for i in range(bs): + if ROI[i] is not None: + W1 = ROI[i][..., 2:3] + W2 = ROI[i][..., 3:4] + H1 = ROI[i][..., 4:5] + H2 = ROI[i][..., 5:6] + psi = ROI[i][..., 6:] + WH = torch.cat((W1 + W2, H1 + H2), dim=-1) + offset = torch.cat(((W1 - W2) / 2, (H1 - H2) / 2), dim=-1) + s = torch.sin(psi).unsqueeze(-1) + c = torch.cos(psi).unsqueeze(-1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + ggi = gg * WH[..., None, None, :] - offset[..., None, None, :] + ggi = ggi @ rotM[..., None, :, :] + ROI[i][..., None, None, 0:2] + + x_sample = ggi[..., 0].flatten() + y_sample = ggi[..., 1].flatten() + res.append( + bilinear_interpolate(features[i], x_sample, y_sample).view( + ggi.shape[0], num_channels, *ggi.shape[1:-1] + ) + ) + else: + res.append(None) + + return res + + +def generate_ROIs_deprecated( + pos, + yaw, + centroid, + scene_yaw, + raster_from_world, + mask, + patch_size, + mode="last", +): + """ + This version generates ROI for all agents only at most recent time step unless specified otherwise + """ + if mode == "all": + bs = pos.shape[0] + yaw = yaw.type(torch.float) + scene_yaw = scene_yaw.type(torch.float) + s = torch.sin(scene_yaw).reshape(-1, 1, 1, 1) + c = torch.cos(scene_yaw).reshape(-1, 1, 1, 1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + world_xy = ((pos.unsqueeze(-2)) @ (rotM.transpose(-1, -2))).squeeze(-2) + world_xy += centroid.view(-1, 1, 1, 2).type(torch.float) + + Mat = raster_from_world.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(world_xy, Mat) + raster_mult = torch.linalg.norm( + raster_from_world[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii, jj = torch.where(mask[i]) + index[i] = (ii, jj) + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, jj], + patches_size, + yaw[i, ii, jj], + ), + dim=-1, + ).to(pos.device) + return ROI, index + elif mode == "last": + num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) + nummask = num * mask + last_idx, _ = torch.max(nummask, dim=2) + bs = pos.shape[0] + scene_yaw = scene_yaw.type(torch.float) + s = torch.sin(scene_yaw).reshape(-1, 1, 1, 1) + c = torch.cos(scene_yaw).reshape(-1, 1, 1, 1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + world_xy = ((pos.unsqueeze(-2)) @ (rotM.transpose(-1, -2))).squeeze(-2) + world_xy += centroid.view(-1, 1, 1, 2).type(torch.float) + Mat = raster_from_world.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(world_xy, Mat) + agent_mask = mask.any(dim=2) + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii = torch.where(agent_mask[i])[0] + index[i] = ii + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, last_idx[i, ii]], + patches_size, + yaw[i, ii, last_idx[i, ii]], + ), + dim=-1, + ) + return ROI, index + else: + raise ValueError("mode must be 'all' or 'last'") + + +def generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size, + mode="last", +): + """ + This version generates ROI for all agents only at most recent time step unless specified otherwise + """ + if mode == "all": + bs = pos.shape[0] + yaw = yaw.type(torch.float) + Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(pos, Mat) + raster_mult = torch.linalg.norm( + raster_from_agent[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii, jj = torch.where(mask[i]) + index[i] = (ii, jj) + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, jj], + patches_size, + yaw[i, ii, jj], + ), + dim=-1, + ).to(pos.device) + return ROI, index + elif mode == "last": + num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) + nummask = num * mask + last_idx, _ = torch.max(nummask, dim=2) + bs = pos.shape[0] + Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(pos, Mat) + raster_mult = torch.linalg.norm( + raster_from_agent[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + agent_mask = mask.any(dim=2) + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii = torch.where(agent_mask[i])[0] + index[i] = ii + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, last_idx[i, ii]], + patches_size, + yaw[i, ii, last_idx[i, ii]], + ), + dim=-1, + ) + return ROI, index + else: + raise ValueError("mode must be 'all' or 'last'") + + +def Indexing_ROI_result(CNN_out, index, emb_size): + """put the lists of ROI align result into embedding tensor with the help of index""" + bs = len(CNN_out) + map_emb = torch.zeros(emb_size).to(CNN_out[0].device) + if map_emb.ndim == 3: + for i in range(bs): + map_emb[i, index[i]] = CNN_out[i] + elif map_emb.ndim == 4: + for i in range(bs): + ii, jj = index[i] + map_emb[i, ii, jj] = CNN_out[i] + else: + raise ValueError("wrong dimension for the map embedding!") + + return map_emb + + +def rasterized_ROI_align( + lane_mask, pos, yaw, raster_from_agent, mask, patch_size, out_dim +): + if pos.ndim == 4: + ROI, index = generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size.type(torch.float), + mode="all", + ) + lane_flags = ROI_align(lane_mask.unsqueeze(1), ROI, out_dim) + lane_flags = [x.mean([-2, -1]).view(x.size(0), 1) for x in lane_flags] + lane_flags = Indexing_ROI_result( + lane_flags, index, [*pos.shape[:3], 1]) + elif pos.ndim == 5: + lane_flags = list() + emb_size = (*pos[:, 0].shape[:-1], 1) + for i in range(pos.size(1)): + ROI, index = generate_ROIs( + pos[:, i], + yaw[:, i], + raster_from_agent, + mask, + patch_size.type(torch.float), + mode="all", + ) + lane_flag_i = ROI_align(lane_mask.unsqueeze(1), ROI, out_dim) + lane_flag_i = [x.mean([-2, -1]).view(x.size(0), 1) + for x in lane_flag_i] + lane_flags.append(Indexing_ROI_result( + lane_flag_i, index, emb_size)) + lane_flags = torch.stack(lane_flags, dim=1) + else: + raise ValueError("wrong shape") + return lane_flags + + +def obtain_map_enc( + image, + map_encoder, + pos, + yaw, + raster_from_agent, + mask, + patch_size, + output_size, + mode, +): + ROI, index = generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size, + mode, + ) + CNN_out = map_encoder(image, ROI) + if mode == "all": + emb_size = (*pos.shape[:-1], output_size) + elif mode == "last": + emb_size = (*pos.shape[:-2], output_size) + + # put the CNN output in the right location of the embedding + map_emb = Indexing_ROI_result(CNN_out, index, emb_size) + return map_emb + + +if __name__ == "__main__": + import numpy as np + from torchvision.ops.roi_align import RoIAlign + + + device = torch.device("cuda") + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + + # create feature layer, proposals and targets + num_proposals = 10 + + bs = 1 + features = torch.randn(bs, 10, 32, 32) + + xy = torch.rand((bs, 5, 2)) * torch.tensor([32, 32]) + WH = torch.ones((bs, 5, 1)) * torch.tensor([1, 1, 1, 1]).view(1, 1, -1) + psi = torch.zeros(bs, 5, 1) + ROI = torch.cat((xy, WH, psi), dim=-1) + ROI = [ROI[i] for i in range(ROI.shape[0])] + res1 = ROI_align(features, ROI, 6)[0].transpose(0, 1) + + ROI_star = torch.cat( + (xy - WH[..., [0, 2]], xy + WH[..., [1, 3]]), dim=-1)[0] + + roi_align_obj = RoIAlign(6, 1, sampling_ratio=2, aligned=False) + res2 = roi_align_obj(features, [ROI_star]) + + res1 - res2 diff --git a/diffstack/models/layers.py b/diffstack/models/layers.py new file mode 100644 index 0000000..848dff5 --- /dev/null +++ b/diffstack/models/layers.py @@ -0,0 +1,782 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + +"""Commonly used network layers and functions.""" + +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +def _pair(x): + if hasattr(x, "__iter__"): + return x + return (x, x) + + +def _triple(x): + if hasattr(x, "__iter__"): + return x + return (x, x, x) + + +def init_last_conv_layer(module, b_prior, w_std=0.01): + """Initializes parameters of a convolutional layer. + + Uses normal distribution for the kernel weights and constant + computed using bias_prior for the prior. + See Focal Loss paper for more details. + """ + for m in module.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d)): + nn.init.normal_(m.weight, mean=0, std=w_std) + nn.init.constant_(m.bias, -np.log((1.0 - b_prior) / b_prior)) + break + + +def soft_threshold(x, thresh, steepness=1024): + """Returns a soft-thresholded version of the tensor {x}. + + Elements less than {thresh} are set to zero, + elements appreciably larger than {thresh} are unchanged, + and elements approaching {thresh} from the positive side + steeply ramp to zero (faster for higher {steepness}). + + Behavior is very similar to th.threshold(x, thresh, 0), + but soft_threshold is composed of operations that are supported + by both the PyTorch 1.4 ONNX exporter and the TRT 6 ONNX parser. + + For comparison, th.threshold with nonzero thresh / value + cannot be exported by PyTorch 1.4, and other formulas + (like x * (x > thresh).float()) use operations + which are not supported by the TensorRT 6 importer. + """ + return x * ((x - thresh) * steepness).clamp(0, 1) + + +# noqa pylint: disable=R0911 +def Activation(act="elu"): + """Create activation function.""" + if act == "elu": + return nn.ELU(inplace=True) + if act == "smooth_elu": + return SmoothELU() + if act == "sigmoid": + return nn.Sigmoid() + if act == "tanh": + return nn.Tanh() + if act == "relu": + return nn.ReLU(inplace=True) + if act == "lrelu": + return nn.LeakyReLU(0.1, inplace=True) + if act is not None: + raise ValueError("Activation is not supported: {}.".format(act)) + return nn.Sequential() + + +def Normalization2D(norm, num_features): + """Create 2D layer normalization (4D inputs).""" + if norm == "bn": + return nn.BatchNorm2d(num_features) + if norm is not None: + raise ValueError("Normalization is not supported: {}.".format(norm)) + return nn.Sequential() + + +def Normalization3D(norm, num_features): + """Create 3D layer normalization (5D inputs).""" + if norm == "bn": + return nn.BatchNorm3d(num_features) + if norm is not None: + raise ValueError("Normalization is not supported: {}.".format(norm)) + return nn.Sequential() + + +def Upsample2D( + mode, in_channels, out_channels, kernel_size, stride, padding, output_padding +): + """Create upsampling layer for 4D input (2D convolution). + + Currently 3 interpolation modes are suported: interp, interp2, and deconv. + interp mode uses nearest neighbor interpolation + convolution. + interp2 mode also uses nearest neighbor interpolation + convolution, + but with a custom implementation that performs better in TRT 5 / 6. + deconv mode uses transposed convolution. + """ + + if stride == 1: + # Skip all upsampling + mode = "interp" + + if mode == "interp": + return UpsampleConv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + use_pth_interp=True, + ) + if mode == "interp2": + return UpsampleConv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + use_pth_interp=False, + ) + if mode == "deconv": + layer = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=stride + 2 * (stride // 2), + stride=stride, + padding=stride // 2, + output_padding=0, + ) + + # fix initialization to not have checkerboarding + with torch.no_grad(): + layer.weight.data[:] = layer.weight.data[:, :, :1, :1] * 0.5 + layer.weight.data[:, :, 1:-1, 1:-1] *= 2 + return layer + raise ValueError("Mode is not supported: {}".format(mode)) + + +def Upsample3D( + mode, in_channels, out_channels, kernel_size, stride, padding, output_padding +): + """Create upsampling layer for 5D input (3D convolution). + + Currently 2 interpolation modes are suported: interp and deconv. + interp mode uses nearest neighbor interpolation + convolution. + deconv mode uses transposed convolution. + """ + + if mode == "interp": + return UpsampleConv3D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + if mode == "deconv": + return nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + raise ValueError("Mode is not supported: {}".format(mode)) + + +class SmoothELU(nn.Module): + """Smooth version of ELU-like activation function. + + ELU derivative is continuous but not smooth. + See Improved Training of WGANs paper for more details. + """ + + def forward(self, x): + """Forward pass.""" + return F.softplus(2.0 * x + 2.0) / 2.0 - 1.0 + + +class UpsampleConv2D(nn.Module): + """Upsampling that uses nearest neighbor + 2D convolution.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + use_pth_interp=True, + ): + """Creates the upsampler. + + use_pth_interp forces using PyTorch interpolation (torch.nn.functional.interpolate) + when applicable, rather than custom interpolation implementation. + """ + super(UpsampleConv2D, self).__init__() + + stride = _pair(stride) + if len(stride) != 2: + raise ValueError( + "Stride must be either int or 2-tuple but got {}".format(stride) + ) + if stride[0] != stride[1]: + raise ValueError("H and W strides must be equal but got {}".format(stride)) + + self._scale_factor = stride[0] + self._use_pth_interp = use_pth_interp + + self.interp = ( + self.interpolation(self._scale_factor) if self._scale_factor > 1 else None + ) + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, # noqa: disable=E221 + stride=1, + padding=padding, + ) + + def interpolation(self, scale_factor=2, mode="nearest"): + """Returns interpolation module.""" + if self._use_pth_interp: + return nn.Upsample(scale_factor=scale_factor, mode=mode) + + return self._int_upsample + + def _int_upsample(self, x): + """Alternative implementation of nearest-neighbor interpolation. + + The main motivation is suboptimal performance of TRT NN interpolation (as of TRT 5/6) + for certain tensor dimensions. The implementation below is about 50% faster + than the current TRT implementation on V100/FP16. + """ + assert x.dim() == 4, "Expected NCHW tensor but got {}.".format(x.size()) + # in TRT have to limit to a basic case of 2x upsample. + assert ( + self._scale_factor == 2 + ), "Only scale factor == 2 is currently supported" " but got {}.".format( + self._scale_factor + ) + + n, c, h, w = [int(d) for d in x.size()] + # 1. Upsample in W dim. + # x = x.view(-1, 1).repeat(1, self._scale_factor) + x = x.reshape(n, c, -1, 1) + # Note: when using interp2, Alfred TRT 6.2 ONNX parser requires 4D tensors. + x = torch.cat((x, x), dim=-1) + # 2. Upsample in H dim. + x = x.reshape(n, c, h, w * 2) + # y = x.repeat(1, 1, 1, self._scale_factor) + y = torch.cat((x, x), dim=3) + y = y.reshape(n, c, h * 2, w * 2) + return y + + def forward(self, x): + """Forward pass.""" + res = self.interp(x) if self.interp is not None else x + return self.conv(res) + + +class UpsampleConv3D(nn.Module): + """Upsampling that uses nearest neighbor + 3D convolution.""" + + def __init__( + self, in_channels, out_channels, kernel_size, stride, padding, output_padding + ): + """Creates the upsampler.""" + super(UpsampleConv3D, self).__init__() + + stride = _pair(stride) + if len(stride) != 3: + raise ValueError( + "Stride must be either int or 3-tuple but got {}".format(stride) + ) + if stride[0] != 1: + raise ValueError( + "Upsampling in D dimension is not supported ({}).".format(stride) + ) + if stride[1] != stride[2]: + raise ValueError("H and W strides must be equal but got {}".format(stride)) + + self.interp = self.interpolation(stride[1]) if stride[1] > 1 else None + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, # noqa: disable=E221 + stride=1, + padding=padding, + ) + + def interpolation(self, scale_factor=2, mode="nearest"): + """Returns interpolation module.""" + return nn.Upsample(scale_factor=scale_factor, mode=mode) + + def forward(self, x): + """Forward pass.""" + res = x + if self.interp is not None: + res = self.interp( + x.view(x.size(0), x.size(1) * x.size(2), x.size(3), x.size(4)) + ) + res = res.view( + res.size(0), x.size(1), x.size(2), res.size(-2), res.size(-1) + ) + return self.conv(res) + + +class ASPPBlock(nn.Module): + """Adaptive spatial pyramid pooling block.""" + + def __init__( + self, + in_channels, + out_channels, + k_hw=3, + norm=None, + act="relu", + dilation_rates=(1, 6, 12), + act_before_combine=False, + act_after_combine=True, + combine_type="add", + combine_conv_k_hw=(3,), + ): + """Creates the ASPP block.""" + super().__init__() + + # Layers with varying dilation rate + self._aspp_layers = nn.ModuleList() + for rate in dilation_rates: + self._aspp_layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=k_hw, + padding=rate * (k_hw // 2), + dilation=(rate, rate), + ), + ) + + # Optional norm / activation before combining + if act_before_combine: + self._before_combine = nn.Sequential( + Normalization2D(norm, out_channels), + Activation(act), + ) + else: + self._before_combine = nn.Sequential() + + self._combine_type = combine_type + + # Layers to apply after combining + post_combine_layers = [] + if combine_type == "concat": + assert len(combine_conv_k_hw) > 0 + post_combine_layers.extend( + [ + nn.Conv2d( + out_channels * len(dilation_rates), + out_channels, + kernel_size=combine_conv_k_hw[0], + padding=combine_conv_k_hw[0] // 2, + ), + Normalization2D(norm, out_channels), + Activation(act), + ] + ) + combine_conv_k_hw = combine_conv_k_hw[1:] + + for c_c_k_hw in combine_conv_k_hw: + post_combine_layers.extend( + [ + nn.Conv2d( + out_channels, + out_channels, + kernel_size=c_c_k_hw, + padding=c_c_k_hw // 2, + ), + Normalization2D(norm, out_channels), + Activation(act), + ] + ) + if not act_after_combine and post_combine_layers: + # Remove the final activation + post_combine_layers = post_combine_layers[:-1] + + self._final_layer = nn.Sequential(*post_combine_layers) + + def _combine_op(self, xs): + if self._combine_type == "concat": + res = torch.cat(xs, -3) + elif self._combine_type == "add": + res = xs[0] + for x in xs[1:]: + res = res + x + else: + raise ValueError( + "Combine type {} is not supported.".format(self._combine_type) + ) + return res + + def forward(self, x): + """Forward pass.""" + aspp_outs = [] + for layer in self._aspp_layers: + aspp_outs.append(self._before_combine(layer(x))) + return self._final_layer(self._combine_op(aspp_outs)) + + +class ResnetBlock2D(nn.Module): + """Residual block with 2D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + s_hw=1, + bottleneck=True, + norm="bn", + act="relu", + ): + """Creates the residual block.""" + super(ResnetBlock2D, self).__init__() + + if k_hw < 3 or (k_hw % 2) == 0: + raise ValueError( + "ResnetBlock2D requires kernel size " + "to be odd and >= 3 but got {}".format(k_hw) + ) + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv2d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, dim), + Activation(act), + # 3x3 (or kxk) + nn.Conv2d(dim, dim, kernel_size=k_hw, stride=s_hw, padding=p_hw), + Normalization2D(norm, dim), + Activation(act), + # 1x1 expand. + nn.Conv2d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + nn.Conv2d( + in_channels, dim, kernel_size=k_hw, stride=s_hw, padding=p_hw + ), + Normalization2D(norm, dim), + Activation(act), + # Second 3x3 (or kxk). + nn.Conv2d(dim, out_channels, kernel_size=k_hw, stride=1, padding=p_hw), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=s_hw, padding=0 + ), + Normalization2D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlockTran2D(nn.Module): + """Transposed residual block with 2D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + s_hw=1, + bottleneck=True, + norm="bn", + act="relu", + upsample="interp", + ): + """Creates the transposed residual block.""" + super(ResnetBlockTran2D, self).__init__() + + if k_hw < 3 or (k_hw % 2) == 0: + raise ValueError( + "ResnetBlockTran2D requires kernel size " + "to be odd and >= 3 but got {}".format(k_hw) + ) + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 + o_p_hw = p_hw if s_hw > 1 else 0 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv2d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, dim), + Activation(act), + # 3x3 (or kxk) + Upsample2D( + upsample, + dim, + dim, + kernel_size=k_hw, + stride=s_hw, + padding=p_hw, + output_padding=o_p_hw, + ), + Normalization2D(norm, dim), + Activation(act), + # 1x1 expand. + nn.Conv2d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + Upsample2D( + upsample, + in_channels, + dim, + kernel_size=k_hw, + stride=s_hw, + padding=p_hw, + output_padding=o_p_hw, + ), + Normalization2D(norm, dim), + Activation(act), + # Second 3x3 (or kxk). + nn.Conv2d(dim, out_channels, kernel_size=k_hw, stride=1, padding=p_hw), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1: + self.shortcut = nn.Sequential( + Upsample2D( + upsample, + in_channels, + out_channels, + kernel_size=1, + stride=s_hw, + padding=0, + output_padding=o_p_hw, + ), + Normalization2D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlock3D(nn.Module): + """Residual block with 3D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + k_d=1, + s_hw=1, + s_d=1, + bottleneck=True, + norm="bn", + act="relu", + ): + """Creates the residual block.""" + super(ResnetBlock3D, self).__init__() + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 # noqa: disable=E221 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv3d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, dim), + self.act, + # 3x3 (or kxk) + nn.Conv3d( + dim, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, dim), + self.act, + # 1x1 expand. + nn.Conv3d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + nn.Conv3d( + in_channels, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, dim), + self.act, + # Second 3x3 (or kxk). + nn.Conv3d( + dim, + out_channels, + kernel_size=(k_d, k_hw, k_hw), + stride=1, + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1 or s_d > 1: + self.shortcut = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=(s_d, s_hw, s_hw), + padding=0, + ), + Normalization3D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) # noqa: disable=E221 + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlockTran3D(nn.Module): + """Transposed residual block with 3D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + k_d=1, + s_hw=1, + s_d=1, + bottleneck=True, + norm="bn", + act="relu", + upsample="interp", + ): + """Creates the transposed residual block.""" + super(ResnetBlockTran3D, self).__init__() + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 # noqa: disable=E221 + o_p_hw = p_hw if s_hw > 1 else 0 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv3d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, dim), + self.act, + # 3x3 (or kxk) + Upsample3D( + upsample, + dim, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, dim), + self.act, + # 1x1 expand. + nn.Conv3d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + Upsample3D( + upsample, + in_channels, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, dim), + self.act, + # Second 3x3 (or kxk). + nn.Conv3d( + dim, + out_channels, + kernel_size=(k_d, k_hw, k_hw), + stride=1, + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1 or s_d > 1 or k_d > 1: + self.shortcut = nn.Sequential( + Upsample3D( + upsample, + in_channels, + out_channels, + kernel_size=(k_d, 1, 1), + stride=(s_d, s_hw, s_hw), + padding=(0, 0, 0), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) # noqa: disable=E221 + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) diff --git a/diffstack/models/learned_metrics.py b/diffstack/models/learned_metrics.py new file mode 100644 index 0000000..079ae89 --- /dev/null +++ b/diffstack/models/learned_metrics.py @@ -0,0 +1,85 @@ +from typing import Dict + +import torch +import torch.nn as nn + +import diffstack.models.base_models as base_models +import diffstack.utils.tensor_utils as TensorUtils + + +class PermuteEBM(nn.Module): + """Raster-based model for planning. + """ + + def __init__( + self, + model_arch: str, + input_image_shape, + map_feature_dim: int, + traj_feature_dim: int, + embedding_dim: int, + embed_layer_dims: tuple + ) -> None: + + super().__init__() + self.map_encoder = base_models.RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=map_feature_dim, + use_spatial_softmax=False, + output_activation=nn.ReLU + ) + self.traj_encoder = base_models.RNNTrajectoryEncoder( + trajectory_dim=3, + rnn_hidden_size=100, + feature_dim=traj_feature_dim + ) + self.embed_net = base_models.MLP( + input_dim=traj_feature_dim + map_feature_dim, + output_dim=embedding_dim, + output_activation=nn.ReLU, + layer_dims=embed_layer_dims + ) + self.score_net = nn.Linear(embedding_dim, 1) + + def forward(self, data_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + image_batch = data_batch["image"] + trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) + bs = image_batch.shape[0] + + map_feat = self.map_encoder(image_batch) # [B, D_m] + traj_feat = self.traj_encoder(trajs) # [B, D_t] + + # construct contrastive samples + map_feat_rep = TensorUtils.unsqueeze_expand_at(map_feat, size=bs, dim=1) # [B, B, D_m] + traj_feat_rep = TensorUtils.unsqueeze_expand_at(traj_feat, size=bs, dim=0) # [B, B, D_t] + cat_rep = torch.cat((map_feat_rep, traj_feat_rep), dim=-1) # [B, B, D_m + D_t] + ebm_rep = TensorUtils.time_distributed(cat_rep, self.embed_net) # [B, B, D] + + # calculate embeddings and scores for InfoNCE loss + scores = TensorUtils.time_distributed(ebm_rep, self.score_net).squeeze(-1) # [B, B] + out_dict = dict(features=ebm_rep, scores=scores) + + return out_dict + + def get_scores(self, data_batch): + image_batch = data_batch["image"] + trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) + + map_feat = self.map_encoder(image_batch) # [B, D_m] + traj_feat = self.traj_encoder(trajs) # [B, D_t] + cat_rep = torch.cat((map_feat, traj_feat), dim=-1) # [B, D_m + D_t] + ebm_rep = self.embed_net(cat_rep) + scores = self.score_net(ebm_rep) + out_dict = dict(features=ebm_rep, scores=scores) + + return out_dict + + def compute_losses(self, pred_batch, data_batch): + scores = pred_batch["scores"] + bs = scores.shape[0] + labels = torch.arange(bs).to(scores.device) + loss = nn.CrossEntropyLoss()(scores, labels) + losses = dict(infoNCE_loss=loss) + + return losses \ No newline at end of file diff --git a/diffstack/models/unet.py b/diffstack/models/unet.py new file mode 100644 index 0000000..2ef73fc --- /dev/null +++ b/diffstack/models/unet.py @@ -0,0 +1,512 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# adapted from predictionnet by Nvidia +from diffstack.models.layers import ( + Activation, + Normalization2D, + Upsample2D) +class UResBlock(nn.Module): + """UNet ResBlock. + + Uses dense 3x3 convs only, single resize *after* skip, + and other graph tweaks specific for UNet. + """ + + def __init__(self, in_channels, out_channels, stride, upsample=None, + norm='bn', act='relu', out_act='relu'): + """Construct UNet ResBlock.""" + super().__init__() + self.act = Activation(out_act) + self.block = nn.Sequential( + nn.Conv2d(in_channels, in_channels, 3, padding=1), + Normalization2D(norm, in_channels), + Activation(act), + nn.Conv2d(in_channels, in_channels, 3, padding=1), + Activation(act), + nn.Conv2d(in_channels, in_channels, 3, padding=1), + ) + self.resize = nn.Conv2d(in_channels, out_channels, 3, padding=1) + if stride > 1: + if upsample is not None: + self.resize = Upsample2D( + upsample, + in_channels, + out_channels, + stride * 2 - 1, + stride=stride, + padding=stride // 2, + output_padding=0 + ) + else: + self.resize = nn.Conv2d( + in_channels, + out_channels, + stride * 2 - 1, + stride=stride, + padding=stride // 2 + ) + + def forward(self, x): + """Forward pass.""" + x = self.block(x) + x + return self.act(self.resize(x)) + +class UNet(nn.Module): + """UNet model.""" + + def __init__(self, in_frames, out_frames, in_channels, out_channels, channels, strides, + decoder_strides=None, skip_type='add', upsample='interp', norm='bn', + activation='relu', last_layer_prior=0.0001, last_layer_mvec_bias=0.0, + last_layer_vel_bias=0.0, enc_out_dropout=0.0, dec_in_dropout=0.0, + trajectory_candidates=0, enable_nll=False, desired_size=None, **kwargs): + """Initialize the model. + + Args: + in_frames: number of input frames. + out_frames: number of output frames. + in_channels: list of [static input channels, dynamic input channels]. + out_channels: list of [output 0 channels, ..., output N channels] for raster outputs. + channels: number of channels in convolutional layers specified with a list. + strides: encoder's convolutional strides specified with a list. + decoder_strides: decoder's convolutional strides specified with a list. + channels, strides, and decoder strides are specified with the same sized list. + skip_type: connection between corresponding encoder/decoder blocks. + Can be add (default), none. + upsample: upsampling type in decoder. Passed to mode parameter of layers.Upsample2D. + norm: normalization type. Can be batch norm (bn) or none. + activation: activation function for all layers. See layers.Activation. + last_layer_prior: last layer bias initialization prior to work properly with + focal loss. + last_layer_mvec_bias: last layer motion vector output initial bias. + last_layer_vel_bias: last layer velocity output initial bias. + enc_out_dropout: encoder output dropout. + dec_in_dropout: decoder input dropout. + trajectory_candidates: number of trajectories to regress directly. + enable_nll: regress output for nll loss if true + desired_size: desired size for the input + """ + super(UNet, self).__init__() + self._in_frames = in_frames # noqa: disable=E221 + self._out_frames = out_frames + self._out_channels = out_channels + if isinstance(out_channels, int): + self._out_channels = [out_channels] + self._skip_type = self._get_str_arg(skip_type, 'skip_type', ['add', 'none']) + self._norm = self._get_str_arg(norm, 'norm', ['bn', 'none']) + self._act = self._get_str_arg(activation, 'activation', ['lrelu', 'relu', 'elu', + 'smooth_elu', 'sigmoid', + 'tanh', 'none']) + self._channels = list(channels) + self._strides = list(strides) # noqa: disable=E221 + if decoder_strides is None: + # symmetrical case, use enc. strides + self._dec_strides = list(reversed(self._strides)) + else: + self._dec_strides = list(decoder_strides) + if len(self._channels) != len(self._strides): + raise ValueError('Number of channels {} must be equal to number of ' + 'strides {}'.format(len(self._channels), len(self._strides))) + self._num_resnet_blocks = len(self._channels) - 1 + static_in_channels, dynamic_in_channels = in_channels + self._in_channels = dynamic_in_channels * in_frames + static_in_channels + # Encoder. + self._e0 = nn.Sequential( + nn.Conv2d(self._in_channels, self._channels[0], kernel_size=3, + stride=self._strides[0], padding=1), + Activation(self._act) + ) + for i in range(1, self._num_resnet_blocks + 1): + block = UResBlock( + self._channels[i - 1], + self._channels[i], + stride=self._strides[i], + upsample=None, + norm=self._norm, + act=self._act, + out_act=self._act + ) + setattr(self, '_e{}'.format(i), block) + self._enc_drop = nn.Dropout2d(p=enc_out_dropout) + self._dec_drop = nn.Dropout2d(p=dec_in_dropout) + output_per_frame = 2 # acce, yaw rate + if enable_nll: + output_per_frame += 3 # log_stdx, log_stdy, rho + total_out_channels = ( + sum(self._out_channels) * out_frames + # BEV Image outputs + trajectory_candidates * out_frames * output_per_frame # Trajectory outputs + ) + # Decoder. + min_last_dec_channel = 64 + for i in range(1, self._num_resnet_blocks + 1): + dec_in_channels_i = self._channels[i] + dec_out_channels_i = self._channels[i - 1] + if i == 1: + # Last decoder channel doesn't need to match encoder because there's no skip + # connection after + dec_out_channels_i = max(dec_out_channels_i, min_last_dec_channel) + block = UResBlock( + dec_in_channels_i, + dec_out_channels_i, + stride=self._dec_strides[self._num_resnet_blocks - i], + upsample=upsample, + norm=self._norm, + act=self._act, + out_act=None + ) + setattr(self, '_d{}'.format(i), block) + self._dec_act = Activation(self._act) + self._trajectory_candidates = trajectory_candidates + # Output layers. + # These are combined into one for efficiency; + # each output head just slices the combined result + output_layer = nn.Sequential( + Upsample2D(upsample, max(min_last_dec_channel, self._channels[0]), + total_out_channels, + kernel_size=3, stride=self._dec_strides[-1], + padding=1, output_padding=0) + ) + # Focal loss init for last layer occupancy channel + last_bias_layer = [m for m in output_layer.modules() if hasattr(m, 'bias')][-1] + with torch.no_grad(): + # most channels are zero-initialized + last_bias_layer.bias.data[:] = 0 + # channels corresponding to occupancy get focal loss initialization + occ_bias = -torch.log((torch.ones(1) - last_layer_prior) / last_layer_prior) + last_bias_layer.bias.data[:out_frames] = occ_bias + # Initialize biases for mvecs and velocity outputs. + # If there are 2 or more outputs, mvecs will be in the last 2 channels. + if len(out_channels) > 1: + last_bias_layer.bias.data[-2 * out_frames:] = last_layer_mvec_bias + # If there are 3 or more outputs, velocity will be in the second group of channels. + if len(out_channels) > 2: + last_bias_layer.bias.data[1*out_frames:3 * out_frames] = last_layer_vel_bias + # Final outputs will be slices of _d0_core + setattr(self, '_d0_core', output_layer) + self.desired_size = desired_size + + @property + def out_channels(self): + """Returns output channels configuration.""" + return self._out_channels + + @staticmethod + def _get_str_arg(val, name, allowed_values): + if val not in allowed_values: + raise ValueError('{} has invalid value {}. Supported values: {}'.format( + name, val, ', '.join(allowed_values))) + if val == 'none': + val = None + return val + + def _skip_op(self, enc, dec): + if self._skip_type == 'add': + res = enc + dec + elif self._skip_type is None: + res = dec + else: + raise ValueError('Skip type {} is not supported.'.format(self._skip_type)) + return res + + def _get_skip_key_for_tensor(self, x): + """Get a key for determining valid skip connections. + + Args: + x: torch Tensor of NCHW data to be used in skip connection. + Returns a key k(x) such that, if k(x) == k(y), then + self._skip_op(x, y) is a valid skip connection. + This is useful when the encoder / decoder are not using matching + block counts / filter counts / strides, and symmetric UNet + skip connections based only on the block index won't work. + Example usage for a simple case: + skips = {} + # during encoder(x) + skips[self._get_skip_key_for_tensor(x)] = x + # ... + # during decoder(y) + y_key = self._get_skip_key_for_tensor(y) + if y_key in skips: + y = self._skip_op(skips.pop(y_key), y) # pop() to prevent reuse + """ + assert x.ndim == 4, f'Invalid {x.shape}' + # N is assumed to always match, + # and H is assumed to match if W does + # (also - broadcasting between H / TH is allowed), + # so create a key just based on C and W. + return (int(x.size(1)), int(x.size(-1))) + + def _apply_encoder(self, x, s): + """Run the encoder tower on a batch of inputs. + + Args: + x: N[CT]HW tensor, representing N episodes of T timesteps of + C channels at HW spatial locations (aka dynamic context). + s: NCHW tensor, representing N episodes of C channels + at HW spatial locations (aka static context). + Returns a tuple ( + a N[CT]HW tensor, + containing final output features from the encoder, + a dictionary of {size: NCHW tensor}, + containing intermediate features from each encoder block; + these are only for the first timestep, and can therefore + be safely used in decoder "skip connections" for all timesteps + without leaking any information backwards in time. + ) + """ + # Check sizes are reasonable + assert ( + x.ndim == 4 + ), f'Invalid x {x.size()}, should be N[CT]HW' + x = torch.cat([s, x], 1) + e_in = self._e0(x) + # Run all encoder blocks + skip_out_dict = {} + # import pdb + # pdb.set_trace() + for i in range(self._num_resnet_blocks): + block = getattr(self, '_e{}'.format(i + 1)) + e_out = block(e_in) + e_in = e_out # noqa: disable=E221 + skip_out_key = self._get_skip_key_for_tensor(e_out) + skip_out_dict[skip_out_key] = skip_out_dict.get(skip_out_key, []) + [e_out] + # Apply dropout on encoder output + if self._enc_drop.p > 0: + e_out = self._enc_drop(e_out) + return e_out, skip_out_dict + + def _apply_decoder(self, x, skip_connections): + """Run the decoder tower on a batch of inputs. + + Args: + x: N[CT]HW tensor, representing N episodes of T timesteps of + C channels at HW spatial locations (aka dynamic context). + skip_connections: a dictionary of {size: NCHW tensor}, + representing output context features from each encoder stage. + Returns an N[CT]HW tensor, + containing decoder output. Output "heads" should + be sliced from the channels of this tensor. + """ + d_in = x + # Apply dropout on decoder input + if self._dec_drop.p > 0: + d_in = self._dec_drop(d_in) + # Run all decoder blocks, applying skip connections + dec_skip_keys = [] + for i in range(self._num_resnet_blocks): + # Apply skip connection + skip_key = self._get_skip_key_for_tensor(d_in) + dec_skip_keys.append(skip_key) + if skip_key in skip_connections and len(skip_connections[skip_key]) > 0: + skip = skip_connections[skip_key].pop() + # do not add the same thing to itself! + if d_in is not skip: + d_in = self._dec_act(self._skip_op(skip, d_in)) + else: + print('failed skip for', skip_key, 'in', skip_connections.keys()) + # Apply block + block = getattr(self, '_d{}'.format(self._num_resnet_blocks - i)) + d_out = block(d_in) + d_in = d_out # noqa: disable=E221 + # Apply shared decoder layer + d_out = self._d0_core(self._dec_act(d_out)) + return d_out + + def is_recurrent(self): + """Check if model is recurrent.""" + return False + + def forward(self, x): + """Forward pass.""" + + # The input contains separate static / dynamic tensors + assert len(x) == 2, f'Invalid x of type {type(x)} length {len(x)}' + s, x = x # Separate static and dynamic parts + assert ( + s.shape[-2:] == x.shape[-2:] + ), f'Invalid static / dynamic shapes: ({s.shape}, {x.shape})' + # In general, 5D tensors are consumed during training, while 4D - during inference. + if x.dim() not in [4, 5]: + raise ValueError(f'Expected 4D(N[CT]HW) or 5D(NCTHW) tensor but got {x.dim()}') + if self.desired_size is not None: + assert s.shape[-2]<=self.desired_size[0] and s.shape[-1]<=self.desired_size[1] + w,h = s.shape[-2:] + pw,ph = self.desired_size[0]-w,self.desired_size[1]-h + pad_w = torch.zeros([*s.shape[:-2],pw,h],dtype=s.dtype,device=s.device) + pad_h = torch.zeros([*s.shape[:-2],self.desired_size[0],ph],dtype=s.dtype,device=s.device) + s = torch.cat((torch.cat((s,pad_w),-2),pad_h),-1) + + pad_w = torch.zeros([*x.shape[:-2],pw,h],dtype=x.dtype,device=x.device) + pad_h = torch.zeros([*x.shape[:-2],self.desired_size[0],ph],dtype=x.dtype,device=x.device) + x = torch.cat((torch.cat((x,pad_w),-2),pad_h),-1) + + + # Save input H & W resolution + + dim_h_input, dim_w_input = x.size()[-2:] + dim_n = int(x.size(0)) + input_is_5d = x.dim() == 5 + if input_is_5d: + # Transform from NCTHW to N[CT]HW. + s = s.view(dim_n, -1, dim_h_input, dim_w_input) + x = x.view(dim_n, -1, dim_h_input, dim_w_input) + + # At this point, everything is 4D N[CT]HW + # (regardless of training or export or whatever); + # we'll convert back to NCTHW at the end if input_is_5d is True + assert ( + s.size(1) + x.size(1) == self._in_channels + ), f'Channel counts in input {s.size(), x.size()} do not match model' + + # Apply encoder / decoder + enc_past, enc_past_skip = self._apply_encoder(x, s=s) + res = self._apply_decoder(enc_past, enc_past_skip) + + # During export - don't do any resizing or slicing; + # those steps will be handled by the DW inference code + if torch.onnx.is_in_onnx_export(): + return res + + # During training, we upsample and slice as needed + res = F.interpolate( + res, + size=(dim_h_input, dim_w_input), + mode='bilinear', + align_corners=False + ) + + # Slice outputs for each output head + res_sliced = [] + slice_start = 0 + dim_t_out = self._out_frames + dim_h_out, dim_w_out = (int(d) for d in res.size()[-2:]) + for dim_c_i in self._out_channels: + slice_end = slice_start + dim_c_i * dim_t_out + res_sliced_i = res[:, slice_start:slice_end] + slice_start = slice_end + # Slice is, again, N[CT]HW + assert res_sliced_i.size() == (dim_n, dim_c_i * dim_t_out, dim_h_out, dim_w_out) + if input_is_5d: + # But, we convert the slice back to NCTHW (5d) if needed + res_sliced_i = res_sliced_i.view(dim_n, dim_c_i, dim_t_out, dim_h_out, dim_w_out) + res_sliced.append(res_sliced_i) + + # Optional trajectory regression output + if self._trajectory_candidates > 0: + # fill in Nones for any other outputs + res_sliced.append(res[:, slice_end:]) + return tuple(res_sliced) + + +class DirectRegressionPostProcessor(torch.nn.Module): + """Direct-regression trajectory extractor. + + This provides a differentiable module for converting a dense DNN output + tensor into a trajectories, given initial positions. + """ + + def __init__( + self, + num_in_frames, + image_height_px, + image_width_px, + sampling_time, + trajectory_candidates, + dyn + ): + """Initialize the post processor. + + Args: + num_in_frames: the value of in_frames used for the predictions. + image_height_px: the vertical spatial resolution of tensors to expect + in postprocessor forward() - i.e. output resolution of prediction DNN. + image_width_px: the corresponding horizontal spatial resolution. + frame_duration_s: the time delta (seconds) between successive input / output frames + trajectory_candidates: number of trajectory candidates to extract. + """ + + super().__init__() + self._num_in_frames = num_in_frames + self._image_height_px = image_height_px + self._image_width_px = image_width_px + assert ( + self._image_height_px == self._image_width_px + ), f"Assumed square pixels, but {self._image_height_px} != {self._image_width_px}" + + self._dt_s = sampling_time + assert self._dt_s > 0, f"Invalid frame duration (s): {self._dt_s}" + self._dyn = dyn + self._trajectory_candidates = trajectory_candidates + + def forward(self, obj_trajectory, query_pts, curr_state, actual_size=None): + """Convert a predicted trajectory tensor and initial positions into a list of FrameTrajectories. + + Args: + obj_trajectory: Torch tensor of dimensions [N, C, H, W] representing + predicted trajectories at each location. For K candidates, C = K * (1 + T * 2), + corresponding to 1 overall confidence score and T (ax, ay) acceleration pairs. + query_pts: location of the agents on the image [N,Na,2] + + Returns a length-N list of FrameTrajectories for each episode in the batch. + """ + bs,Na = query_pts.shape[:2] + if actual_size is None: + pos_xy_rel = query_pts.unsqueeze(1)/(obj_trajectory.shape[-1]-1)*2-1 + else: + pos_xy_rel = query_pts.unsqueeze(1)/(actual_size-1)*2-1 + + pred_per_agent = ( + torch.nn.functional.grid_sample( + obj_trajectory, + pos_xy_rel, + mode="bilinear", + align_corners=True, + ) + .squeeze(2) + .transpose(1,2) + ) + + input_pred = pred_per_agent.reshape( + bs*Na, self._trajectory_candidates, -1, 2 + ) # AxTxCandidatesx2 + traj = self._dyn.forward_dynamics(curr_state.reshape(bs*Na,1,-1).repeat_interleave(self._trajectory_candidates,1),input_pred) + pos = self.dyn.state2pos(traj) + yaw = self.dyn.state2yaw(traj) + + return traj,pos,yaw,input_pred + + +def main(): + num_in_frames = 10 + num_out_frames = 10 + model = UNet(in_frames=num_in_frames, out_frames=num_out_frames, in_channels=[2,3], + out_channels=[1], channels=[32, 64, 128, 128, 256, 256, 256], + strides=[2, 2, 2, 2, 2, 2, 2], decoder_strides=[2, 2, 2, 2, 2, 1, 1], + skip_type='add', upsample='interp', norm='bn', + activation='relu', last_layer_prior=0.0001, last_layer_mvec_bias=0.0, + last_layer_vel_bias=0.0, enc_out_dropout=0.0, dec_in_dropout=0.0, + trajectory_candidates=2, enable_nll=False) + + static_input = torch.ones([10,2,255,255]) + dynamic_input = torch.zeros([10,3,10,255,255]) + out = model((static_input,dynamic_input)) + query_pts = torch.rand([10,5,2])*255 + from diffstack.dynamics import Unicycle + dyn = Unicycle(0.1, vbound=[-1, 40.0]) + curr_state=torch.randn([10,5,4]) + + postprocessor = DirectRegressionPostProcessor(num_in_frames, + 255, + 255, + 0.1, + 2, + dyn, + ) + pred_logits, pred_traj = out + traj = postprocessor(pred_traj,query_pts,curr_state) + import pdb + pdb.set_trace() + + + +if __name__=="__main__": + main() \ No newline at end of file diff --git a/diffstack/models/vaes.py b/diffstack/models/vaes.py new file mode 100644 index 0000000..4747a32 --- /dev/null +++ b/diffstack/models/vaes.py @@ -0,0 +1,1348 @@ +"""Variants of Conditional Variational Autoencoder (C-VAE)""" +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffstack.utils.loss_utils import KLD_0_1_loss, KLD_gaussian_loss, KLD_discrete +from diffstack.utils.torch_utils import reparameterize +from diffstack.models.base_models import MLP, SplitMLP +import diffstack.utils.tensor_utils as TensorUtils +from torch.distributions import Categorical, kl_divergence + + +class Prior(nn.Module): + def __init__(self, latent_dim, input_dim=None, device=None): + """ + A generic prior class + Args: + latent_dim: dimension of the latent code (e.g., mu, logvar) + input_dim: (Optional) dimension of the input feature vector, for conditional prior + device: + """ + super(Prior, self).__init__() + self._latent_dim = latent_dim + self._input_dim = input_dim + self._net = None + self._device = device + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + raise NotImplementedError + + @staticmethod + def get_mean(prior_params): + """ + Extract the "mean" of a prior distribution (not supported by all priors) + + Args: + prior_params (torch.Tensor): a batch of prior parameters + + Returns: + mean (torch.Tensor): the "mean" of the distribution + """ + raise NotImplementedError + + def sample(self, n: int, inputs: torch.Tensor = None): + """ + Take samples with the prior distribution + + Args: + n (int): number of samples to take + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [input_batch_size, n, latent_dim] + """ + prior_params = self.forward(inputs=inputs) + return self.sample_with_parameters(prior_params, n=n) + + @staticmethod + def sample_with_parameters(params: dict, n: int): + """ + Take samples using given a batch of distribution parameters, e.g., mean and logvar of a unit Gaussian + + Args: + params (dict): a batch of distribution parameters + n (int): number of samples to take + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [param_batch_size, n, latent_dim] + """ + raise NotImplementedError + + def kl_loss(self, posterior_params: dict, inputs: torch.Tensor = None) -> torch.Tensor: + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + raise NotImplementedError + + @property + def posterior_param_shapes(self) -> dict: + """ + Shape of the posterior parameters + + Returns: + shapes (dict): a dictionary of parameter shapes + """ + raise NotImplementedError + + @property + def latent_dim(self): + """ + Shape of the latent code + + Returns: + latent_dim (int) + """ + return self._latent_dim + + +class FixedGaussianPrior(Prior): + """An unassuming unit Gaussian Prior""" + def __init__(self, latent_dim, input_dim=None, device=None): + super(FixedGaussianPrior, self).__init__( + latent_dim=latent_dim, input_dim=input_dim, device=device) + self._params = nn.ParameterDict({ + "mu": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=False), + "logvar": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=False) + }) + + @staticmethod + def get_mean(prior_params): + return prior_params["mu"] + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + + batch_size = 1 if inputs is None else inputs.shape[0] + params = TensorUtils.unsqueeze_expand_at(self._params, size=batch_size, dim=0) + return params + + @staticmethod + def sample_with_parameters(params, n: int): + """ + Take samples using given a batch of distribution parameters, e.g., mean and logvar of a unit Gaussian + + Args: + params (dict): a batch of distribution parameters + n (int): number of samples to take + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [param_batch_size, n, latent_dim] + """ + + batch_size = params["mu"].shape[0] + params_tiled = TensorUtils.repeat_by_expand_at(params, repeats=n, dim=0) + samples = reparameterize(params_tiled["mu"], params_tiled["logvar"]) + samples = TensorUtils.reshape_dimensions(samples, begin_axis=0, end_axis=1, target_dims=(batch_size, n)) + return samples + + def kl_loss(self, posterior_params, inputs=None): + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + + assert posterior_params["mu"].shape[1] == self._latent_dim + assert posterior_params["logvar"].shape[1] == self._latent_dim + return KLD_0_1_loss( + mu=posterior_params["mu"], + logvar=posterior_params["logvar"] + ) + + @property + def posterior_param_shapes(self) -> OrderedDict: + return OrderedDict(mu=(self._latent_dim,), logvar=(self._latent_dim,)) + + +class ConditionalCategoricalPrior(Prior): + """ + A class that holds functionality for learning categorical priors for use + in VAEs. + """ + def __init__(self, latent_dim, input_dim=None, device=None): + """ + Args: + latent_dim (int): size of latent dimension for the prior + device (torch.Device): where the module should live (i.e. cpu, gpu) + """ + super(ConditionalCategoricalPrior, self).__init__(latent_dim=latent_dim, input_dim=input_dim, device=device) + assert isinstance(input_dim, int) and input_dim > 0 + self.device = device + self._latent_dim = latent_dim + self._prior_net = MLP(input_dim=input_dim, output_dim=latent_dim) + + def sample(self, n: int, inputs: torch.Tensor = None): + """ + Returns a batch of samples from the prior distribution. + Args: + n (int): this argument is used to specify the number + of samples to generate from the prior. + inputs (torch.Tensor): conditioning feature for prior + Returns: + z (torch.Tensor): batch of sampled latent vectors. + """ + + # check consistency between n and obs_dict + if self.learnable: + + # forward to get parameters + out = self.forward(batch_size=n, obs_dict=obs_dict, goal_dict=goal_dict) + prior_logits = out["logit"] + + # sample one-hot latents from categorical distribution + dist = Categorical(logits=prior_logits) + z = TensorUtils.to_one_hot(dist.sample(), num_class=self.categorical_dim) + + else: + # try to include a categorical sample for each class if possible (ensuring rough uniformity) + if (self.latent_dim == 1) and (self.categorical_dim <= n): + # include samples [0, 1, ..., C - 1] and then repeat until batch is filled + dist_samples = torch.arange(n).remainder(self.categorical_dim).unsqueeze(-1).to(self.device) + else: + # sample one-hot latents from uniform categorical distribution for each latent dimension + probs = torch.ones(n, self.latent_dim, self.categorical_dim).float().to(self.device) + dist_samples = Categorical(probs=probs).sample() + z = TensorUtils.to_one_hot(dist_samples, num_class=self.categorical_dim) + + # reshape [B, D, C] to [B, D * C] to be consistent with other priors that return flat latents + z = z.reshape(*z.shape[:-2], -1) + return z + + def kl_loss(self, posterior_params, inputs = None): + """ + Computes KL divergence loss between the Categorical distribution + given by the unnormalized logits @logits and the prior distribution. + Args: + posterior_params (dict): dictionary with key "logits" corresponding + to torch.Tensor batch of unnormalized logits of shape [B, D * C] + that corresponds to the posterior categorical distribution + Returns: + kl_loss (torch.Tensor): KL divergence loss + """ + prior_logits = self._prior_net(inputs) + + prior_dist = Categorical(logits=prior_logits) + posterior_dist = Categorical(logits=posterior_params["logits"]) + + # sum over latent dimensions, but average over batch dimension + kl_loss = kl_divergence(posterior_dist, prior_dist) + assert len(kl_loss.shape) == 2 + return kl_loss.sum(-1).mean() + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + prior_logits = self._prior_net(inputs) + return prior_logits + + +class LearnedGaussianPrior(FixedGaussianPrior): + """A Gaussian prior with learnable parameters""" + def __init__(self, latent_dim, input_dim=None, device=None): + super(LearnedGaussianPrior, self).__init__( + latent_dim=latent_dim, input_dim=input_dim, device=device) + self._params = nn.ParameterDict({ + "mu": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=True), + "logvar": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=True) + }) + + def kl_loss(self, posterior_params, inputs=None): + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + + assert posterior_params["mu"].shape[1] == self._latent_dim + assert posterior_params["logvar"].shape[1] == self._latent_dim + + batch_size = posterior_params["mu"].shape[0] + prior_params = TensorUtils.unsqueeze_expand_at(self._params, size=batch_size, dim=0) + return KLD_gaussian_loss( + mu_1=posterior_params["mu"], + logvar_1=posterior_params["logvar"], + mu_2=prior_params["mu"], + logvar_2=prior_params["logvar"] + ) + + +class CVAE(nn.Module): + def __init__( + self, + q_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + prior: Prior + ): + """ + A basic Conditional Variational Autoencoder Network (C-VAE) + + Args: + q_net (nn.Module): a model that encodes data (x) and condition inputs (x_c) to posterior (q) parameters + c_net (nn.Module): a model that encodes condition inputs (x_c) into condition feature (c) + decoder (nn.Module): a model that decodes latent (z) and condition feature (c) to data (x') + prior (nn.Module): a model containing information about distribution prior (kl-loss, prior params, etc.) + """ + super(CVAE, self).__init__() + self.q_net = q_net + self.c_net = c_net + self.decoder = decoder + self.prior = prior + + def sample(self, condition_inputs, n: int, condition_feature=None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + z = self.prior.sample(n=n, inputs=c) # z of shape [B (from c), N, ...] + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * N, ...] + c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * N, ...] + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(c.shape[0], n)) + return x_out + + def predict(self, condition_inputs, condition_feature=None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + + prior_params = self.prior(c) # [B, ...] + mu = self.prior.get_mean(prior_params) # [B, ...] + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=mu, condition_features=c, **decoder_kwargs) + return x_out + + def forward(self, inputs, condition_inputs, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + c = self.c_net(condition_inputs) # [B, ...] + q_params = self.q_net(inputs=inputs, condition_features=c) + z = self.prior.sample_with_parameters(q_params, n=1).squeeze(dim=1) + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z, condition_features=c, **decoder_kwargs) + return {"x_recons": x_out, "q_params": q_params, "z": z, "c": c} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + return self.prior.kl_loss(outputs["q_params"], inputs=outputs["c"]) + + +class DiscreteCVAE(nn.Module): + def __init__( + self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + K: int, + recon_loss_fun=None, + logpi_clamp = None, + ): + """ + A basic Conditional Variational Autoencoder Network (C-VAE) + + Args: + q_net (nn.Module): a model that encodes data (x) and condition inputs (x_c) to posterior (q) parameters + p_net (nn.Module): a model that encodes condition feature (c) to latent (p) parameters + c_net (nn.Module): a model that encodes condition inputs (x_c) into condition feature (c) + decoder (nn.Module): a model that decodes latent (z) and condition feature (c) to data (x') + K (int): cardinality of the discrete latent + recon_loss: loss function handle for reconstruction loss + logpi_clamp (float): lower bound of the logpis, for numerical stability + """ + super(DiscreteCVAE, self).__init__() + self.q_net = q_net + self.p_net = p_net + self.c_net = c_net + self.decoder = decoder + self.K = K + self.logpi_clamp= logpi_clamp + if recon_loss_fun is None: + self.recon_loss_fun = nn.MSELoss(reduction="none") + else: + self.recon_loss_fun = recon_loss_fun + + def sample(self, condition_inputs, n: int, condition_feature=None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + logp = self.p_net(c)["logp"] + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + p = p/p.sum(dim=-1,keepdim=True) + # z = (-logp).argsort()[...,:n] + # z = F.one_hot(z,self.K) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * N, ...] + c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * N, ...] + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(c.shape[0], n)) + x_out["z"] = z_samples + return x_out + + def predict(self, condition_inputs, condition_feature=None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + + logp = self.p_net(c)["logp"] + z = logp.argmax(dim=-1) + + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=F.one_hot(z,self.K), condition_features=c, **decoder_kwargs) + return x_out + + def forward(self, inputs, condition_inputs, n=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + if n is None: + n = self.K + c = self.c_net(condition_inputs) # [B, ...] + logq = self.q_net(inputs=inputs, condition_features=c)["logq"] + logp = self.p_net(c)["logp"] + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + q = q/q.sum(dim=-1,keepdim=True) + + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + z = (-logq).argsort()[...,:n] + z = F.one_hot(z,self.K) + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + c_tiled = c.unsqueeze(1).repeat(1,n,1) + x_out = self.decoder(latents=z.reshape(-1,self.K), condition_features=c_tiled.reshape(-1,c.shape[-1]), **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(z.shape[0],n)) + return {"x_recons": x_out, "q": q, "p": p, "z": z, "c": c} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class ECDiscreteCVAE(DiscreteCVAE): + def sample(self, condition_inputs, n: int,cond_traj = None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] + + + bs,Na = c.shape[:2] + c_joined = TensorUtils.join_dimensions(c,0,2) + logp = self.p_net(c_joined)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + # z = (-logp).argsort()[...,:n] + # z = F.one_hot(z,self.K) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na * n, ...] + c_samples = TensorUtils.repeat_by_expand_at(c_joined, repeats=n, dim=0) # [B * Na * n, ...] + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,n,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na,n)) + + return x_out + + def predict(self, condition_inputs, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] + + bs,Na = c.shape[:2] + c_joined = TensorUtils.join_dimensions(c,0,2) + logp = self.p_net(c_joined)["logp"] + z = logp.argmax(dim=-1) + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,2) + x_out = self.decoder(latents=F.one_hot(z,self.K), condition_features=c_joined, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na)) + return x_out + + def forward(self, inputs, condition_inputs, cond_traj, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] + c_joined = TensorUtils.join_dimensions(c,0,2) + bs,Na = c.shape[:2] + logp = TensorUtils.reshape_dimensions(self.p_net(c_joined)["logp"],0,1,(bs,Na)) + if inputs is not None: + inputs_tiled = TensorUtils.unsqueeze_expand_at(inputs,Na,1) + inputs_joined = TensorUtils.join_dimensions(inputs_tiled,0,2) + logq = TensorUtils.reshape_dimensions(self.q_net(inputs=inputs_joined, condition_features=c_joined)["logq"],0,1,(bs,Na)) + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + + z = torch.arange(self.K).to(q.device).tile(*q.shape[:-1],1) + z = F.one_hot(z,self.K) + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.K,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + + c_tiled = c[:,:,None].repeat(1,1,self.K,1) + x_out = self.decoder(latents=z.reshape(-1,self.K), condition_features=TensorUtils.join_dimensions(c_tiled,0,3), **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs,Na,self.K)) + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class SceneDiscreteCVAE(DiscreteCVAE): + def __init__(self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + transformer: nn.Module, + K: int, + aggregate_func="max", + recon_loss_fun=None, + logpi_clamp = None, + num_latent_sample = None, + ): + super(SceneDiscreteCVAE,self).__init__(q_net, p_net, c_net, decoder, K, recon_loss_fun, logpi_clamp) + self.transformer = transformer + self.aggregate_func = aggregate_func + if num_latent_sample is None: + num_latent_sample = self.K + assert num_latent_sample<=self.K + self.num_latent_sample = num_latent_sample + + def sample(self, condition_inputs, mask, pos, n: int,cond_traj = None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) [B, Na, D] + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) #[B,n,K] + z = TensorUtils.repeat_by_expand_at(z,repeats=Na,dim=0) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na * n, ...] + c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * Na * n, ...] + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,n,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na,n)) + + return x_out + + def predict(self, condition_inputs, mask, pos, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + + + z = p.argmax(-1) #[B] + z = F.one_hot(z, self.K) #[B,K] + z = TensorUtils.repeat_by_expand_at(z,repeats=Na,dim=0) #[B*Na,K] + + z = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na , ...] + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,2) + x_out = self.decoder(latents=z, condition_features=c, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na)) + + return x_out + + def forward(self, inputs, condition_inputs, mask, pos, cond_traj=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = TensorUtils.join_dimensions(cond_traj,0,2) #[B*Na,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + if inputs is not None: + # inputs_joined = TensorUtils.join_dimensions(inputs,0,2) + logq = self.q_net(inputs,c,mask,pos)["logq"] + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + + _,z = torch.topk(q,self.num_latent_sample,dim=-1) + p = p.take_along_dim(z,-1) + q = q.take_along_dim(z,-1) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + z = F.one_hot(z,self.K) + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.num_latent_sample,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + + c_tiled = c[:,:,None].repeat(1,1,self.num_latent_sample,1) + x_out = self.decoder( + latents=TensorUtils.unsqueeze_expand_at(z,Na,1).reshape(-1,self.K), + condition_features=TensorUtils.join_dimensions(c_tiled,0,3), + **decoder_kwargs + ) + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs,Na,self.num_latent_sample)) + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class SceneDiscreteCVAEDiverse(SceneDiscreteCVAE): + def __init__(self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + transformer: nn.Module, + latent_embeding: nn.Module, + K: int, + aggregate_func="max", + recon_loss_fun=None, + logpi_clamp = None, + num_latent_sample=None): + super(SceneDiscreteCVAEDiverse,self).__init__(q_net,p_net,c_net, decoder, transformer, K, aggregate_func, recon_loss_fun, logpi_clamp, num_latent_sample) + self.latent_embeding = latent_embeding + def sample(self, condition_inputs, mask, pos, n: int,cond_traj = None, decoder_kwargs=None): + + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) [B, Na, D] + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + res = self.forward(None, condition_inputs, mask, pos, cond_traj, decoder_kwargs) + dis_p = Categorical(probs=res["p"]) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + x_out = {k:v[z] for k,v in res["x_recons"].items()} + return x_out + + def predict(self, condition_inputs, mask, pos, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + z_enum = torch.eye(self.K).repeat(bs,Na,1,1).to(c.device) + z_emb = self.latent_embeding(z_enum) + latent_emb_dim = z_emb.shape[-1] + c_tiled = c.unsqueeze(-2).repeat_interleave(self.K,-2) + cz_tiled = torch.cat((c_tiled,z_emb),-1) + + cz_tiled = TensorUtils.join_dimensions(cz_tiled.transpose(1,2),0,2) + mask_tiled = mask.repeat_interleave(self.K,0) + pos_tiled = pos.repeat_interleave(self.K,0) + + cz = self.transformer(cz_tiled,mask_tiled,pos_tiled)+cz_tiled + if self.aggregate_func == "max": + c_agg = cz.max(1)[0][...,-latent_emb_dim:] + elif self.aggregate_func == "mean": + c_agg = cz.mean(1)[...,-latent_emb_dim:] + logp = self.p_net(c_agg)["logp"] + logp = logp.reshape(bs,self.K) + + + p = torch.exp(logp) + + if self.logpi_clamp is not None: + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + p = p/p.sum(dim=-1,keepdim=True) + z_idx = p.argmax(dim=1) + z_idx = torch.arange(bs).to(c.device)*self.K+z_idx + + + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.K,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(TensorUtils.join_dimensions(cz[z_idx],0,2), **decoder_kwargs) + + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs, Na)) + + return x_out + + def forward(self, inputs, condition_inputs, mask, pos, cond_traj=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = TensorUtils.join_dimensions(cond_traj,0,2) #[B*Na,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + z_enum = torch.eye(self.K).repeat(bs,Na,1,1).to(c.device) + z_emb = self.latent_embeding(z_enum) + latent_emb_dim = z_emb.shape[-1] + c_tiled = c.unsqueeze(-2).repeat_interleave(self.K,-2) + cz_tiled = torch.cat((c_tiled,z_emb),-1) + + cz_tiled = TensorUtils.join_dimensions(cz_tiled.transpose(1,2),0,2) + mask_tiled = mask.repeat_interleave(self.K,0) + pos_tiled = pos.repeat_interleave(self.K,0) + + cz = self.transformer(cz_tiled,mask_tiled,pos_tiled)+cz_tiled + if self.aggregate_func == "max": + c_agg = cz.max(1)[0][...,-latent_emb_dim:] + elif self.aggregate_func == "mean": + c_agg = cz.mean(1)[...,-latent_emb_dim:] + logp = self.p_net(c_agg)["logp"] + logp = logp.reshape(bs,self.K) + + + p = torch.exp(logp) + if inputs is not None: + # inputs_joined = TensorUtils.join_dimensions(inputs,0,2) + inputs_tiled = {k:v.repeat_interleave(self.K,0) for k,v in inputs.items()} + logq = self.q_net(inputs_tiled,cz[...,-latent_emb_dim:],mask_tiled,pos_tiled)["logq"] + logq = logq.reshape(bs,self.K) + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=3.0) + logp = logp.clamp(min=self.logpi_clamp,max=3.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + + _,z = torch.topk(q,self.num_latent_sample,dim=-1) + p = p.take_along_dim(z,-1) + q = q.take_along_dim(z,-1) + + cz = cz.reshape(bs,Na,self.K,-1) + cz = cz.take_along_dim(z[:,None,:,None],2) + cz = cz.transpose(1,2) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.num_latent_sample,1) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(TensorUtils.join_dimensions(cz,0,3), **decoder_kwargs) + + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs, self.num_latent_sample, Na)) + x_out = {k:v.transpose(1,2) for k,v in x_out.items()} + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + + + +def main(): + import diffstack.models.base_models as l5m + + inputs = OrderedDict(trajectories=torch.randn(10, 50, 3)) + condition_inputs = OrderedDict(image=torch.randn(10, 3, 224, 224)) + + condition_dim = 16 + latent_dim = 4 + + prior = FixedGaussianPrior(latent_dim=4) + + map_encoder = l5m.RasterizedMapEncoder( + model_arch="resnet18", + num_input_channels=3, + feature_dim=128 + ) + + q_encoder = l5m.PosteriorEncoder( + condition_dim=condition_dim, + trajectory_shape=(50, 3), + output_shapes=OrderedDict(mu=(latent_dim,), logvar=(latent_dim,)) + ) + c_encoder = l5m.ConditionEncoder( + map_encoder=map_encoder, + trajectory_shape=(50, 3), + condition_dim=condition_dim + ) + decoder = l5m.ConditionDecoder( + condition_dim=condition_dim, + latent_dim=latent_dim, + output_shapes=OrderedDict(trajectories=(50, 3)) + ) + + model = CVAE( + q_net=q_encoder, + c_net=c_encoder, + decoder=decoder, + prior=prior, + target_criterion=nn.MSELoss(reduction="none") + ) + + + outputs = model(inputs=inputs, condition_inputs=condition_inputs) + losses = model.compute_losses(outputs=outputs, targets=inputs) + samples = model.sample(condition_inputs=condition_inputs, n=10) + print() + + traj_encoder = l5m.RNNTrajectoryEncoder( + trajectory_dim=3, + rnn_hidden_size=100 + ) + + c_net = l5m.ConditionNet( + condition_input_shapes=OrderedDict( + map_feature=(map_encoder.output_shape()[-1],) + ), + condition_dim=condition_dim, + ) + + q_net = l5m.PosteriorNet( + input_shapes=OrderedDict( + traj_feature=(traj_encoder.output_shape()[-1],) + ), + condition_dim=condition_dim, + param_shapes=prior.posterior_param_shapes, + ) + + lean_model = CVAE( + q_net=q_net, + c_net=c_net, + decoder=decoder, + prior=prior, + target_criterion=nn.MSELoss(reduction="none") + ) + + map_feats = map_encoder(condition_inputs["image"]) + traj_feats = traj_encoder(inputs["trajectories"]) + input_feats = dict(traj_feature=traj_feats) + condition_feats = dict(map_feature=map_feats) + + outputs = lean_model(inputs=input_feats, condition_inputs=condition_feats) + losses = lean_model.compute_losses(outputs=outputs, targets=inputs) + samples = lean_model.sample(condition_inputs=condition_feats, n=10) + print() + + +def main_discrete(): + import diffstack.models.base_models as l5m + + inputs = OrderedDict(trajectories=torch.randn(10, 50, 3)) + condition_inputs = OrderedDict(image=torch.randn(10, 3, 224, 224)) + + condition_dim = 16 + latent_dim = 20 + + map_encoder = l5m.RasterizedMapEncoder( + model_arch="resnet18", + feature_dim=128 + ) + + q_encoder = l5m.PosteriorEncoder( + condition_dim=condition_dim, + trajectory_shape=(50, 3), + output_shapes=OrderedDict(logq=(latent_dim,)) + ) + p_encoder = l5m.SplitMLP( + input_dim=condition_dim, + layer_dims=(128,128), + output_shapes=OrderedDict(logp=(latent_dim,)) + ) + c_encoder = l5m.ConditionEncoder( + map_encoder=map_encoder, + trajectory_shape=(50, 3), + condition_dim=condition_dim + ) + decoder_MLP = l5m.SplitMLP( + input_dim=condition_dim+latent_dim, + output_shapes=OrderedDict(trajectories=(50, 3)), + layer_dims=(128,128), + output_activation=nn.ReLU, + ) + decoder = l5m.ConditionDecoder(decoder_model=decoder_MLP) + + model = DiscreteCVAE( + q_net=q_encoder, + p_net=p_encoder, + c_net=c_encoder, + decoder=decoder, + K=latent_dim, + ) + + + outputs = model(inputs=inputs, condition_inputs=condition_inputs) + losses = model.compute_losses(outputs=outputs, targets = inputs) + samples = model.sample(condition_inputs=condition_inputs, n=10) + KL_loss = model.compute_kl_loss(outputs) + + + # traj_encoder = l5m.RNNTrajectoryEncoder( + # trajectory_dim=3, + # rnn_hidden_size=100 + # ) + + # c_net = l5m.ConditionNet( + # condition_input_shapes=OrderedDict( + # map_feature=(map_encoder.output_shape()[-1],) + # ), + # condition_dim=condition_dim, + # ) + + # q_net = l5m.PosteriorNet( + # input_shapes=OrderedDict( + # traj_feature=(traj_encoder.output_shape()[-1],) + # ), + # condition_dim=condition_dim, + # param_shapes=prior.posterior_param_shapes, + # ) + + # lean_model = CVAE( + # q_net=q_net, + # c_net=c_net, + # decoder=decoder, + # prior=prior, + # target_criterion=nn.MSELoss(reduction="none") + # ) + + # map_feats = map_encoder(condition_inputs["image"]) + # traj_feats = traj_encoder(inputs["trajectories"]) + # input_feats = dict(traj_feature=traj_feats) + # condition_feats = dict(map_feature=map_feats) + + # outputs = lean_model(inputs=input_feats, condition_inputs=condition_feats) + # losses = lean_model.compute_losses(outputs=outputs, targets=inputs) + # samples = lean_model.sample(condition_inputs=condition_feats, n=10) + # print() + +if __name__ == "__main__": + main_discrete() \ No newline at end of file diff --git a/diffstack/modules/__init__.py b/diffstack/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/modules/metric_models/metrics.py b/diffstack/modules/metric_models/metrics.py new file mode 100644 index 0000000..a448251 --- /dev/null +++ b/diffstack/modules/metric_models/metrics.py @@ -0,0 +1,321 @@ +from collections import OrderedDict +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +import pytorch_lightning as pl + +from diffstack.models.learned_metrics import PermuteEBM +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import transform_points_tensor +from diffstack.models.base_models import RasterizedMapUNet, SplitMLP +from diffstack.models.Transformer import SimpleTransformer +import diffstack.utils.algo_utils as AlgoUtils + + + +class EBMMetric(pl.LightningModule): + def __init__(self, algo_config, modality_shapes, do_log=True): + """ + Creates networks and places them into @self.nets. + """ + super(EBMMetric, self).__init__() + self.algo_config = algo_config + self.nets = nn.ModuleDict() + self._do_log = do_log + assert modality_shapes["image"][0] == 15 + + self.nets["ebm"] = PermuteEBM( + model_arch=algo_config.model_architecture, + input_image_shape=modality_shapes["image"], # [C, H, W] + map_feature_dim=algo_config.map_feature_dim, + traj_feature_dim=algo_config.traj_feature_dim, + embedding_dim=algo_config.embedding_dim, + embed_layer_dims=algo_config.embed_layer_dims, + ) + + @property + def checkpoint_monitor_keys(self): + return {"valLoss": "val/losses_infoNCE_loss"} + + def forward(self, obs_dict): + return self.nets["ebm"](obs_dict) + + def _compute_metrics(self, pred_batch, data_batch): + scores = pred_batch["scores"] + pred_inds = torch.argmax(scores, dim=1) + gt_inds = torch.arange(scores.shape[0]).to(scores.device) + cls_acc = torch.mean((pred_inds == gt_inds).float()).item() + + return dict(cls_acc=cls_acc) + + def training_step(self, batch, batch_idx): + """ + Training on a single batch of data. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader and filtered by @process_batch_for_training + + batch_idx (int): training step number - required by some Algos that need + to perform staged training and early stopping + + Returns: + info (dict): dictionary of relevant inputs, outputs, and losses + that might be relevant for logging + """ + batch = batch_utils().parse_batch(batch) + pout = self.nets["ebm"](batch) + losses = self.nets["ebm"].compute_losses(pout, batch) + total_loss = 0.0 + for lk, l in losses.items(): + losses[lk] = l * self.algo_config.loss_weights[lk] + total_loss += losses[lk] + + metrics = self._compute_metrics(pout, batch) + + for lk, l in losses.items(): + self.log("train/losses_" + lk, l) + for mk, m in metrics.items(): + self.log("train/metrics_" + mk, m) + + return { + "loss": total_loss, + "all_losses": losses, + "all_metrics": metrics + } + + def validation_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self.nets["ebm"](batch) + losses = TensorUtils.detach(self.nets["ebm"].compute_losses(pout, batch)) + metrics = self._compute_metrics(pout, batch) + return {"losses": losses, "metrics": metrics} + + def validation_epoch_end(self, outputs) -> None: + for k in outputs[0]["losses"]: + m = torch.stack([o["losses"][k] for o in outputs]).mean() + self.log("val/losses_" + k, m) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m) + + def configure_optimizers(self): + optim_params = self.algo_config.optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def get_metrics(self, obs_dict): + preds = self.forward(obs_dict) + return dict( + scores=preds["scores"].detach() + ) + + +class OccupancyMetric(pl.LightningModule): + def __init__(self, algo_config, modality_shapes): + super(OccupancyMetric, self).__init__() + self.algo_config = algo_config + self.nets = nn.ModuleDict() + self.agent_future_cond = algo_config.agent_future_cond.enabled + if algo_config.agent_future_cond.enabled: + self.agent_future_every_n_frame = algo_config.agent_future_cond.every_n_frame + self.future_num_frames = int(np.floor(algo_config.future_num_frames/self.agent_future_every_n_frame)) + C,H,W = modality_shapes["image"] + modality_shapes["image"] = (C+self.future_num_frames,H,W) + + self.nets["policy"] = RasterizedMapUNet( + model_arch=algo_config.model_architecture, + input_image_shape=modality_shapes["image"], # [C, H, W] + output_channel=algo_config.future_num_frames + ) + + + + @property + def checkpoint_monitor_keys(self): + keys = {"posErr": "val/metrics_pos_selection_err"} + if self.algo_config.loss_weights.pixel_bce_loss > 0: + keys["valBCELoss"] = "val/losses_pixel_bce_loss" + if self.algo_config.loss_weights.pixel_ce_loss > 0: + keys["valCELoss"] = "val/losses_pixel_ce_loss" + return keys + + def rasterize_agent_future(self,obs_dict): + + b, t_h, h, w = obs_dict["image"].shape # [B, C, H, W] + t_f = self.future_num_frames + + # create spatial supervisions + agent_positions = obs_dict["all_other_agents_future_positions"][:,:,::self.agent_future_every_n_frame] + + pos_raster = transform_points_tensor( + agent_positions.reshape(b,-1,2), + obs_dict["raster_from_agent"].float() + ).reshape(b,-1,t_f,2).long() # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + # compute flattened pixel location + hist_image = torch.zeros([b,t_f,h*w],dtype=obs_dict["image"].dtype,device=obs_dict["image"].device) + raster_hist_pos_flat = pos_raster[..., 1] * w + pos_raster[..., 0] # [B, T, A] + raster_hist_pos_flat = (raster_hist_pos_flat * obs_dict["all_other_agents_future_availability"][:,:,::self.agent_future_every_n_frame]).long() + + hist_image.scatter_(dim=2, index=raster_hist_pos_flat.transpose(1,2), src=torch.ones_like(hist_image)) # mark other agents with -1 + + hist_image[:, :, 0] = 0 # correct the 0th index from invalid positions + hist_image[:, :, -1] = 0 # correct the maximum index caused by out of bound locations + + return hist_image.reshape(b, t_f, h, w) + + + def forward(self, obs_dict, mask_drivable=False, num_samples=None, clearance=None): + if self.agent_future_cond: + hist_image = self.rasterize_agent_future(obs_dict) + image = torch.cat([obs_dict["image"],hist_image],1) + else: + image = obs_dict["image"] + + pred_map = self.nets["policy"](image) + + return { + "occupancy_map": pred_map + } + + def compute_likelihood(self, occupancy_map, traj_pos, raster_from_agent): + b, t, h, w = occupancy_map.shape # [B, C, H, W] + + # create spatial supervisions + pos_raster = transform_points_tensor( + traj_pos, + raster_from_agent.float() + ) # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + pos_pixel = torch.floor(pos_raster).float() # round down pixels + + # compute flattened pixel location + pos_pixel_flat = pos_pixel[..., 1] * w + pos_pixel[..., 0] # [B, T] + occupancy_map_flat = occupancy_map.reshape(b, t, -1) + + joint_li_map = torch.softmax(occupancy_map_flat, dim=-1) # [B, T, H * W] + joint_li = torch.gather(joint_li_map, dim=2, index=pos_pixel_flat[:, :, None].long()).squeeze(-1) + indep_li_map = torch.sigmoid(occupancy_map_flat) + indep_li = torch.gather(indep_li_map, dim=2, index=pos_pixel_flat[:, :, None].long()).squeeze(-1) + return { + "indep_likelihood": indep_li, + "joint_likelihood": joint_li + } + + def _compute_losses(self, pred_batch, data_batch): + losses = dict() + pred_map = pred_batch["occupancy_map"] + b, t, h, w = pred_map.shape + + spatial_sup = data_batch["spatial_sup"] + mask = data_batch["target_availabilities"] # [B, T] + # compute pixel classification loss + bce_loss = torch.binary_cross_entropy_with_logits( + input=pred_map, # [B, T, H, W] + target=spatial_sup["traj_spatial_map"], # [B, T, H, W] + ) * mask[..., None, None] + losses["pixel_bce_loss"] = bce_loss.mean() + + ce_loss = torch.nn.CrossEntropyLoss(reduction="none")( + input=pred_map.reshape(b * t, h * w), + target=spatial_sup["traj_position_pixel_flat"].long().reshape(b * t), + ) * mask.reshape(b * t) + + losses["pixel_ce_loss"] = ce_loss.mean() + + return losses + + def _compute_metrics(self, pred_batch, data_batch): + metrics = dict() + spatial_sup = data_batch["spatial_sup"] + + pixel_pred = torch.argmax( + torch.flatten(pred_batch["occupancy_map"], start_dim=2), dim=2 + ) # [B, T] + metrics["pos_selection_err"] = torch.mean( + (spatial_sup["traj_position_pixel_flat"].long() != pixel_pred).float() + ) + + likelihood = self.compute_likelihood( + pred_batch["occupancy_map"], + data_batch["target_positions"], + data_batch["raster_from_agent"] + ) + + metrics["joint_likelihood"] = likelihood["joint_likelihood"].mean() + metrics["indep_likelihood"] = likelihood["indep_likelihood"].mean() + + metrics = TensorUtils.to_numpy(metrics) + for k, v in metrics.items(): + metrics[k] = float(v) + return metrics + + def training_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self.forward(batch) + batch["spatial_sup"] = AlgoUtils.get_spatial_trajectory_supervision(batch) + losses = self._compute_losses(pout, batch) + total_loss = 0.0 + for lk, l in losses.items(): + loss = l * self.algo_config.loss_weights[lk] + self.log("train/losses_" + lk, loss) + total_loss += loss + + with torch.no_grad(): + metrics = self._compute_metrics(pout, batch) + for mk, m in metrics.items(): + self.log("train/metrics_" + mk, m) + + return total_loss + + def validation_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self(batch) + batch["spatial_sup"] = AlgoUtils.get_spatial_trajectory_supervision(batch) + losses = TensorUtils.detach(self._compute_losses(pout, batch)) + metrics = self._compute_metrics(pout, batch) + return {"losses": losses, "metrics": metrics} + + def validation_epoch_end(self, outputs) -> None: + for k in outputs[0]["losses"]: + m = torch.stack([o["losses"][k] for o in outputs]).mean() + self.log("val/losses_" + k, m) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m) + + def configure_optimizers(self): + optim_params = self.algo_config.optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def get_metrics(self, obs_dict,horizon=None): + occup_map = self.forward(obs_dict)["occupancy_map"] + b, t, h, w = occup_map.shape # [B, C, H, W] + if horizon is None: + horizon = t + else: + assert horizon<=t + li = self.compute_likelihood(occup_map, obs_dict["target_positions"], obs_dict["raster_from_agent"]) + li["joint_likelihood"] = li["joint_likelihood"][:,:horizon].mean(dim=-1).detach() + li["indep_likelihood"] = li["indep_likelihood"][:,:horizon].mean(dim=-1).detach() + + return li \ No newline at end of file diff --git a/diffstack/modules/module.py b/diffstack/modules/module.py new file mode 100644 index 0000000..efea2cf --- /dev/null +++ b/diffstack/modules/module.py @@ -0,0 +1,438 @@ +import dill +import numpy as np +import torch + +from collections import OrderedDict +from enum import IntEnum +from typing import Dict, Optional, Union, Any, List, Set + +from diffstack.utils.model_registrar import ModelRegistrar + + +class RunMode(IntEnum): + TRAIN = 0 + VALIDATE = 1 + INFER = 2 + + +class DataFormat(object): + def __init__(self, required_elements: Set[str]) -> None: + self.required_elements = required_elements + + def satisfied_by(self, data_dict: Dict[str, Any]) -> bool: + return all(x in data_dict for x in self.required_elements) + + def __iter__(self) -> str: + for x in self.required_elements: + yield x + + def for_run_mode(self, run_mode: RunMode): + elements = [] + for k in self.required_elements: + ksplit = k.split(":") + if len(ksplit) == 1: + elements.append(k) + elif len(ksplit) == 2: + if ksplit[1].upper() == run_mode.name: + elements.append(ksplit[0]) + return DataFormat(elements) + + +class Module(torch.nn.Module): + """Abstract module in a differentiable stack. + + Inheriting classes need to implement: + - input_format + - output_format + - train_step() + - validate_step() + - infer_step() + + - train/validate/infer methods. + """ + + @property + def name(self) -> str: + self.__class__.__name__ + + @property + def input_format(self) -> DataFormat: + """Required input keys specified as a set of strings wrapped as a DataFormat. + + The naming convention `my_input:run_mode` is used to identify a key + `my_input` that is only required for run mode `run_mode`. + + Example: + return DataFormat(["rgb_image", "pointcloud", "label:train"]) + """ + return None + + @property + def output_format(self) -> DataFormat: + """Output keys specified as a set of strings wrapped as a DataFormat. + + The naming convention `my_output:run_mode` is used to identify a key + `my_output` that is only provided for run mode `run_mode`. + + Example: + return DataFormat(["prediction", "loss:train", "ml_prediction:infer"]) + """ + return None + + def __init__( + self, + model_registrar: Optional[ModelRegistrar] = None, + hyperparams: Optional[Dict[str, Any]] = None, + log_writer: Optional[Any] = None, + device: Optional[str] = None, + input_mappings: Dict[str, str] = {}, + ) -> None: + """ + Args: + model_registrar (ModelRegistrar): handles the registration of trainable parameters + hyperparams (dict): config parameters + log_writer (wandb.Run): `wandb.Run` object for logging or None. + device (str): torch device + input_mappings (dict): a remapping of input names of the format {target_name: source_name}. + This is used when connecting multiple modules and we need to rename the outputs of the + previous module to the inputs of this module. + """ + super().__init__() + self.model_registrar = model_registrar + self.hyperparams = hyperparams + self.log_writer = log_writer + self.device = device + self.input_mappings = input_mappings + + # Initialize epoch counter + self.curr_iter = 0 + self.curr_epoch = 0 + + def apply_input_mappings( + self, inputs: Dict, run_mode: RunMode, allow_partial: bool = True + ): + mapped_inputs = {} + for k in self.input_format.for_run_mode(run_mode): + if k in self.input_mappings: + if self.input_mappings[k] in inputs: + mapped_inputs[k] = inputs[self.input_mappings[k]] + elif not allow_partial: + raise ValueError( + f"Key `{k}` is remapped `{k}`<--`{self.input_mappings[k]}` but " + + f"there is no key `{self.input_mappings[k]}` in inputs.\n " + + f"inputs={list(inputs.keys())};\n " + + f"input_mappings={self.input_mappings}" + ) + elif k in inputs: + mapped_inputs[k] = inputs[k] + elif "input." + k in inputs: + mapped_inputs[k] = inputs["input." + k] + elif not allow_partial: + raise ValueError( + f"Key `{k}` is not found in inputs and input_mappings.\n " + + f"inputs={list(inputs.keys())};\n " + + f"input_mappings={self.input_mappings}" + ) + return mapped_inputs + + # Optional functions for tracking training iteration/epoch and annealers. + # This logic is inherited from Trajectron++. + def set_curr_iter(self, curr_iter): + self.curr_iter = curr_iter + + def set_curr_epoch(self, curr_epoch): + self.curr_epoch = curr_epoch + + def set_annealing_params(self): + pass + + def step_annealers(self, node_type=None): + pass + + def forward(self, inputs: Dict, **kwargs) -> Dict: + return self.train_step(inputs, **kwargs) + + # Abstract methods that children classes should implement. + def train_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.TRAIN, **kwargs) + + def validate_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.VALIDATE, **kwargs) + + def infer_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.INFER, **kwargs) + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + raise NotImplementedError + + def set_eval(): + pass + + def set_train(): + pass + + def reset(self): + pass + + +class ModuleSequence(Module): + """A sequence of stack modules defined by an ordered dict of {name: Module}. + + The output of component N will be fed to the input of component N+1. + When the output names of module N does not match the input names of module N+1 + we can specify the name mapping with the `input_mappings` argument of module N+1. + The outputs of all previous modules can be referenced by `modulename.outputname`. + The overall inputs can be referenced by `input.inputname`. + + The output will the sequence will contain the outputs of the last component, as + well as the outputs of all components in the `modulename.outputname` format. + + Example: + We have two modules with the following inputs and outputs: + MyPredictor: {'agent_history'} -> {'most_likely_pred'} + MyPlanner: {'prediction', 'ego_state', 'goal'} -> {'plan_x'} + + We can sequence them in the followin way. + stack = ModuleSequence(OrderedDict( + pred=MyPredictor(), + plan=MyPlanner(input_mappings={ + 'prediction': 'pred.most_likely_pred' + 'ego_state': 'input.ego_state', + 'goal': 'input.ego_goal', + }))) + + input_dict = {'agent_history', 'ego_state', 'ego_goal'} + output_dict = stack.train_step(input_dict) + + Now `output_dict.keys()` will contain [ + 'input.agent_history', + 'input.ego_state', + 'input.ego_goal', + 'pred.most_likely_pred', + 'plan.plan_x', + 'plan_x' + ] + """ + + @property + def input_format(self) -> DataFormat: + list(self.components.values())[0].input_format + + @property + def output_format(self) -> DataFormat: + list(self.components.values())[-1].output_format + + def __init__( + self, + components: OrderedDict, + model_registrar, + hyperparams, + log_writer, + device, + input_mappings: Dict[str, str] = {}, + ) -> None: + super().__init__( + model_registrar, + hyperparams, + log_writer, + device, + input_mappings=input_mappings, + ) + self.components = components + + if "input" in self.components: + raise ValueError( + 'Name "input" is reserved for the overall input to the module sequence.' + ) + + def validate_interfaces(self, desired_input: DataFormat) -> bool: + data_dict = {k: None for k in desired_input} + for component in self.components.values(): + if component.input_format.satisfied_by(data_dict): + data_dict.update({k: None for k in component.output_format}) + else: + missing_inputs = [ + x + for x in component.input_format.required_elements + if x not in data_dict + ] + raise ValueError( + f"Component '{component.name}' missing input(s): {missing_inputs}" + ) + + return data_dict + + def sequence_components( + self, + inputs: Dict, + run_mode: RunMode, + pre_module_cb: Optional[callable] = None, + post_module_cb: Optional[callable] = None, + **kwargs, + ): + """Sequence the ordered dict of components by connecting outputs to inputs.""" + all_outputs = {"input." + k: v for k, v in inputs.items()} + next_inputs = {**all_outputs, **inputs, **kwargs} + + for comp_i, (name, component) in enumerate(self.components.items()): + component: Module + + if pre_module_cb is not None: + pre_module_cb(name, component) + + if component.input_format is None: + inputs = next_inputs + else: + inputs = component.apply_input_mappings(next_inputs, run_mode) + + if run_mode == RunMode.TRAIN: + output = component.train_step(inputs, **kwargs) + elif run_mode == RunMode.VALIDATE: + output = component.validate_step(inputs, **kwargs) + elif run_mode == RunMode.INFER: + output = component.infer_step(inputs, **kwargs) + else: + raise ValueError(f"Unknown mode {run_mode}") + + if post_module_cb is not None: + post_module_cb(name, component) + + # Add to all_outputs + all_outputs.update({f"{name}.{k}": v for k, v in output.items()}) + + # Construct input from all_outputs and the previous outputs. + next_inputs = {**all_outputs, **output} + + return next_inputs + + def dry_run( + self, + input_keys: List[str] = None, + run_mode: Optional[RunMode] = None, + check_output: bool = True, + raise_error: bool = True, + ) -> List[str]: + """Checks that all inputs are defined in module sequence. + + We only verify input and output based on module.input_format and + module.output_format. We do not check if the modules correctly + define their respective input_format and output_format. + + Args: + input_keys: list of input keys we will feed to the module. If None we + will use `self.input_format`. + run_mode: run mode. If not specified we will check for all possible run modes. + check_output: check that `self.output_format` is satisfied by last component output. + raise_error: will raise error for an issues if True + Returns: + list of found issues represented as strings + Raises: + ValueError if some inputs are not defined. + """ + issues = [] + if run_mode is None: + for run_mode in RunMode: + issues.extend( + self.dry_run(input_keys, run_mode, check_output, raise_error=False) + ) + if issues: + break + + else: + if input_keys is None: + input_keys = list(self.input_format) + inputs = {k: None for k in input_keys} + all_outputs = {"input." + k: v for k, v in inputs.items()} + next_inputs = {**all_outputs, **inputs} + + component: Module + for name, component in self.components.items(): + inputs = component.apply_input_mappings( + next_inputs, run_mode, allow_partial=True + ) + input_format = component.input_format.for_run_mode(run_mode) + if not input_format.satisfied_by(inputs): + issues.append( + f"Component {name}({run_mode.name}): \n Required inputs: {list(input_format)}\n Provided inputs: {list(inputs.keys())}" + ) + + output = { + k: None for k in component.output_format.for_run_mode(run_mode) + } + # Add to all_outputs + all_outputs.update({f"{name}.{k}": v for k, v in output.items()}) + # Construct input from all_outputs and the previous outputs. + next_inputs = {**all_outputs, **output} + + if check_output: + output_format = self.output_format.for_run_mode(run_mode) + if not output_format.satisfied_by(next_inputs): + issues.append( + f"Sequence {self.name}({run_mode.name}): \n Required outputs: {list(output_format)}\n Provided outputs: {list(next_inputs.keys())}" + ) + + if raise_error and issues: + print("\n".join(issues)) + raise ValueError("\n".join(issues)) + + return issues + + def set_curr_iter(self, curr_iter): + super().set_curr_iter(curr_iter) + for comp in self.components.values(): + comp.set_curr_iter(curr_iter) + + def set_curr_epoch(self, curr_epoch): + super().set_curr_epoch(curr_epoch) + for comp in self.components.values(): + comp.set_curr_epoch(curr_epoch) + + def set_annealing_params(self): + super().set_annealing_params() + for comp in self.components.values(): + comp.set_annealing_params() + + def step_annealers(self, node_type=None): + super().step_annealers(node_type) + for comp in self.components.values(): + comp.step_annealers(node_type) + + def train_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.TRAIN, **kwargs) + + def validate_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.VALIDATE, **kwargs) + + def infer_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.INFER, **kwargs) + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + return self.sequence_components(inputs, run_mode=run_mode, **kwargs) + + def __getstate__(self): + # Custom getstate for pickle that allows for lambda functions. + import dill # reimport in case not available in forked function + + return dill.dumps(self.__dict__) + + def __setstate__(self, state): + # Custom setstate for pickle that allows for lambda functions. + state = dill.loads(state) + self.__dict__.update(state) + + def __str__(self): + return f"Module sequence (device={self.device}) \n " + "\n ".join( + self.components.keys() + ) + + def set_eval(self): + for k, v in self.components.items(): + v.set_eval() + + def set_train(self): + for k, v in self.components.items(): + v.set_train() + + def reset(self): + for k, v in self.components.items(): + v.reset() diff --git a/diffstack/modules/predictors/CTT.py b/diffstack/modules/predictors/CTT.py new file mode 100644 index 0000000..f788fe8 --- /dev/null +++ b/diffstack/modules/predictors/CTT.py @@ -0,0 +1,1673 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import wandb +import os +from functools import partial +from typing import Dict +from collections import OrderedDict, defaultdict + +from diffstack.dynamics.unicycle import Unicycle, Unicycle_xyvsc +from diffstack.dynamics.unicycle import Unicycle +from diffstack.modules.module import Module, DataFormat, RunMode + +from diffstack.utils.utils import removeprefix +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.metrics as Metrics +from diffstack.utils.geometry_utils import ratan2 +import diffstack.utils.lane_utils as LaneUtils +from diffstack.utils.batch_utils import batch_utils +import diffstack.utils.model_utils as ModelUtils + +from trajdata.utils.map_utils import LaneSegRelation +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +from diffstack.utils.vis_utils import plot_scene_open_loop, animate_scene_open_loop + +from diffstack.utils.loss_utils import collision_loss, loss_clip +from pytorch_lightning.loggers import WandbLogger +from pathlib import Path + + +from diffstack.models.CTT import ( + CTT, + TFvars, + FeatureAxes, + GNNedges, +) +from trajdata.data_structures import AgentType +from diffstack.utils.homotopy import ( + HomotopyType, + identify_pairwise_homotopy, + HOMOTOPY_THRESHOLD, +) +from bokeh.plotting import figure, output_file, save +from bokeh.models import Range1d +from bokeh.io import export_png + + +class CTTTrafficModel(Module): + @property + def input_format(self) -> DataFormat: + return DataFormat(["scene_batch"]) + + @property + def output_format(self) -> DataFormat: + return DataFormat(["pred", "pred_dist", "pred_ml", "sample_modes", "step_time"]) + + @property + def checkpoint_monitor_keys(self): + return {"valLoss": "val/losses_predictor_marginal_lm_loss"} + + def __init__(self, model_registrar, cfg, log_writer, device, input_mappings={}): + super(CTTTrafficModel, self).__init__( + model_registrar, cfg, log_writer, device, input_mappings=input_mappings + ) + + self.bu = batch_utils(rasterize_mode="none") + self.modality_shapes = self.bu.get_modality_shapes(cfg) + + self.hist_lane_relation = LaneUtils.LaneRelationFromCfg(cfg.hist_lane_relation) + self.fut_lane_relation = LaneUtils.LaneRelationFromCfg(cfg.fut_lane_relation) + self.cfg = cfg + self.nets = nn.ModuleDict() + + self.create_nets(cfg) + self.fp16 = cfg.fp16 if "fp16" in cfg else False + + if "checkpoint" in cfg and cfg.checkpoint["enabled"]: + checkpoint = torch.load(cfg.checkpoint["path"]) + predictor_dict = { + removeprefix(k, "components.predictor."): v + for k, v in checkpoint["state_dict"].items() + if k.startswith("components.predictor.") + } + self.load_state_dict(predictor_dict) + + self.device = device + self.nets["policy"] = self.nets["policy"].to(self.device) + + # setup loss functions + + self.lane_mode_loss = nn.CrossEntropyLoss( + reduction="none", label_smoothing=0.01 + ) + self.homotopy_loss = nn.CrossEntropyLoss(reduction="none", label_smoothing=0.01) + self.joint_mode_loss = nn.CrossEntropyLoss( + reduction="none", label_smoothing=0.05 + ) + self.traj_loss = nn.MSELoss(reduction="none") + self.unicycle_model = Unicycle(cfg.step_time) + + def create_nets(self, cfg): + enc_var_axes = { + TFvars.Agent_hist: "B,A,T,F", + TFvars.Lane: "B,L,F", + } + agent_ntype = len(AgentType) + auxdim_l = 4 * cfg.num_lane_pts # (x,y,s,c) x num_pts + + agent_raw_dim = agent_ntype + 5 # agent type, l,w,v,a,r + + enc_attn_attributes = OrderedDict() + # each attention operation is done on a pair of variables, e.g. (agent_hist, agent_hist), on a specific axis, e.g. T + # to specify the recipe for factorized attention, we need to specify the following: + # 1. the dimension of the edge feature + # 2. the edge function (if any) + # 3. the number of separate attention mechanisms (in case we need to separate the self-attention from the cross-attention or more generally, if we need to separate the attention for different types of edges) + # 4. whether to use normalization for the embedding + + d_lm_hist = len(self.hist_lane_relation) + d_lm_fut = len(self.fut_lane_relation) + d_ll = len(LaneSegRelation) + d_static = 2 + 1 + len(AgentType) + a2a_edge_func = partial( + ModelUtils.agent2agent_edge, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + l2l_edge_func = partial( + ModelUtils.lane2lane_edge, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + if cfg.a2l_edge_type == "proj": + a2l_edge_func = partial( + ModelUtils.agent2lane_edge_proj, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + a2l_edge_dim = cfg.edge_dim.a2l + elif cfg.a2l_edge_type == "attn": + self.nets["a2l_edge"] = ModelUtils.Agent2Lane_emb_attn( + edge_dim=4, + n_embd=cfg.a2l_n_embd, + agent_feat_dim=d_static, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + Lmax=cfg.num_lane_pts, + ) + a2l_edge_dim = cfg.a2l_n_embd + a2l_edge_func = self.nets["a2l_edge"].forward + + self.edge_func = dict(a2a=a2a_edge_func, l2l=l2l_edge_func, a2l=a2l_edge_func) + enc_attn_attributes[(TFvars.Agent_hist, TFvars.Agent_hist, FeatureAxes.T)] = ( + cfg.edge_dim.a2a, + a2a_edge_func, + 1, + False, + ) + enc_attn_attributes[(TFvars.Agent_hist, TFvars.Agent_hist, FeatureAxes.A)] = ( + cfg.edge_dim.a2a, + a2a_edge_func, + cfg.attn_ntype.a2a, + False, + ) + enc_attn_attributes[(TFvars.Lane, TFvars.Lane, FeatureAxes.L)] = ( + cfg.edge_dim.l2l + d_ll, + None, + cfg.attn_ntype.l2l, + False, + ) + enc_attn_attributes[ + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ] = (a2l_edge_dim + d_lm_hist, None, cfg.attn_ntype.a2l, False) + + enc_transformer_kwargs = dict( + n_embd=cfg.n_embd, + n_head=cfg.n_head, + PE_mode=cfg.PE_mode, + use_rpe_net=cfg.use_rpe_net, + attn_attributes=enc_attn_attributes, + var_axes=enc_var_axes, + attn_pdrop=cfg.encoder.attn_pdrop, + resid_pdrop=cfg.encoder.resid_pdrop, + MAX_T=cfg.future_num_frames + cfg.history_num_frames + 2, + ) + + lane_margins_dim = 5 + + # decoder design + dec_attn_attributes = OrderedDict() + + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Agent_hist, (FeatureAxes.T, FeatureAxes.T)) + ] = (cfg.edge_dim.a2a, a2a_edge_func, 1, False) + + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A) + ] = (cfg.edge_dim.a2a + len(HomotopyType) * 2, None, 2, False) + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ] = (a2l_edge_dim + d_lm_fut + lane_margins_dim, None, 1, False) + dec_var_axes = { + TFvars.Agent_hist: "B,A,T,F", + TFvars.Lane: "B,L,F", + TFvars.Agent_future: "B,A,T,F", + } + dec_transformer_kwargs = dict( + n_embd=cfg.n_embd, + n_head=cfg.n_head, + PE_mode=cfg.PE_mode, + use_rpe_net=cfg.use_rpe_net, + attn_attributes=dec_attn_attributes, + var_axes=dec_var_axes, + attn_pdrop=cfg.decoder.attn_pdrop, + resid_pdrop=cfg.decoder.resid_pdrop, + ) + # embedding functions of raw variables + embed_funcs = { + TFvars.Agent_hist: ModelUtils.Agent_emb(agent_raw_dim, cfg.n_embd), + TFvars.Lane: ModelUtils.Lane_emb( + auxdim_l, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + TFvars.Agent_future: ModelUtils.Agent_emb(agent_raw_dim, cfg.n_embd), + GNNedges.Agenthist2Agenthist: ModelUtils.Agent2Agent_emb( + cfg.edge_dim.a2a, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Agentfuture: ModelUtils.Agent2Agent_emb( + cfg.edge_dim.a2a + len(HomotopyType) * 2, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + if cfg.a2l_edge_type == "proj": + embed_funcs.update( + { + GNNedges.Agenthist2Lane: ModelUtils.Agent2Lane_emb_proj( + cfg.edge_dim.a2l, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Lane: ModelUtils.Agent2Lane_emb_proj( + cfg.edge_dim.a2l + d_lm_fut + lane_margins_dim, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + ) + else: + embed_funcs.update( + { + GNNedges.Agenthist2Lane: ModelUtils.Agent2Lane_emb_attn( + 4, + cfg.a2l_n_embd, + d_static, + output_dim=cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Lane: ModelUtils.Agent2Lane_emb_attn( + 4, + cfg.a2l_n_embd, + d_static, + output_dim=cfg.n_embd, + aux_edge_dim=d_lm_fut + lane_margins_dim, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + ) + + enc_GNN_attributes = OrderedDict() + enc_GNN_attributes[(GNNedges.Agenthist2Agenthist, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + enc_GNN_attributes[(GNNedges.Agenthist2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + enc_GNN_attributes[(GNNedges.Agenthist2Lane, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + enc_GNN_attributes[ + (GNNedges.Agenthist2Agenthist, "node", TFvars.Agent_hist) + ] = ([cfg.n_embd * 2], None, cfg.encoder.pooling, None) + + node_n_embd = defaultdict(lambda: cfg.n_embd) + edge_n_embd = defaultdict(lambda: cfg.n_embd) + + enc_edge_var = { + GNNedges.Agenthist2Agenthist: (TFvars.Agent_hist, TFvars.Agent_hist), + GNNedges.Agenthist2Lane: (TFvars.Agent_hist, TFvars.Lane), + } + dec_edge_var = { + GNNedges.Agenthist2Agenthist: (TFvars.Agent_hist, TFvars.Agent_hist), + GNNedges.Agenthist2Lane: (TFvars.Agent_hist, TFvars.Lane), + GNNedges.Agentfuture2Agentfuture: ( + TFvars.Agent_future, + TFvars.Agent_future, + ), + GNNedges.Agentfuture2Lane: (TFvars.Agent_future, TFvars.Lane), + } + enc_GNN_kwargs = dict( + var_axes=enc_var_axes, + GNN_attributes=enc_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=edge_n_embd, + edge_var=enc_edge_var, + ) + + JM_GNN_attributes = OrderedDict() + JM_GNN_attributes[(GNNedges.Agenthist2Agenthist, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Lane, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Agenthist, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + + JM_edge_n_embd = { + GNNedges.Agenthist2Agenthist: cfg.n_embd + cfg.encoder.mode_embed_dim, + GNNedges.Agenthist2Lane: cfg.n_embd + cfg.encoder.mode_embed_dim, + } + + JM_GNN_kwargs = dict( + var_axes=enc_var_axes, + GNN_attributes=JM_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=JM_edge_n_embd, + edge_var=enc_edge_var, + ) + + enc_output_params = dict( + pooling_T=cfg.encoder.pooling, + Th=cfg.history_num_frames + 1, + lane_mode=dict( + n_head=4, + hidden_dim=[cfg.n_embd], + ), + homotopy=dict( + n_head=4, + hidden_dim=[cfg.n_embd], + ), + joint_mode=dict( + n_head=4, + jm_GNN_nblock=cfg.encoder.jm_GNN_nblock, + GNN_kwargs=JM_GNN_kwargs, + num_joint_samples=cfg.encoder.num_joint_samples, + num_joint_factors=cfg.encoder.num_joint_factors, + ), + mode_embed_dim=cfg.encoder.mode_embed_dim, + null_lane_mode=cfg.encoder.null_lane_mode, + PE_mode=cfg.PE_mode, + ) + self.null_lane_mode = cfg.encoder.null_lane_mode + + dec_GNN_attributes = OrderedDict() + if cfg.decoder.GNN_enabled: + dec_GNN_attributes[(GNNedges.Agentfuture2Agentfuture, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + dec_GNN_attributes[(GNNedges.Agentfuture2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + dec_GNN_attributes[ + (GNNedges.Agentfuture2Lane, "node", TFvars.Agent_future) + ] = ([cfg.n_embd * 2], None, cfg.decoder.pooling, None) + dec_GNN_attributes[ + (GNNedges.Agentfuture2Agentfuture, "node", TFvars.Agent_future) + ] = ([cfg.n_embd * 2], None, cfg.decoder.pooling, None) + + dec_GNN_kwargs = dict( + var_axes=dec_var_axes, + GNN_attributes=dec_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=edge_n_embd, + edge_var=dec_edge_var, + ) + + dyn = dict() + name_type_table = { + "vehicle": AgentType.VEHICLE, + "pedestrian": AgentType.PEDESTRIAN, + "bicycle": AgentType.BICYCLE, + "motorcycle": AgentType.MOTORCYCLE, + } + + for k, v in cfg.decoder.dyn.items(): + if v == "unicycle": + dyn[name_type_table[k]] = Unicycle(cfg.step_time) + elif v == "unicycle_xyvsc": + dyn[name_type_table[k]] = Unicycle_xyvsc(cfg.step_time) + elif v == "DI_unicycle": + dyn[name_type_table[k]] = Unicycle(cfg.step_time, max_steer=1e3) + else: + dyn[name_type_table[k]] = None + + self.weighted_consistency_loss = cfg.weighted_consistency_loss + dec_output_params = dict( + arch=cfg.decoder.arch, + dyn=dyn, + num_layers=cfg.decoder.num_layers, + lstm_hidden_size=cfg.decoder.lstm_hidden_size, + mlp_hidden_dims=cfg.decoder.mlp_hidden_dims, + traj_dim=cfg.decoder.traj_dim, + dt=cfg.step_time, + Tf=cfg.future_num_frames, + decode_num_modes=cfg.decoder.decode_num_modes, + AR_step_size=cfg.decoder.AR_step_size, + AR_update_mode=cfg.decoder.AR_update_mode, + LR_sample_hack=cfg.LR_sample_hack, + dec_rounds=cfg.decoder.dec_rounds, + ) + + assert ( + cfg.decoder.AR_step_size == 1 + ) # for now, we only support AR_step_size=1 for non-auto-regressive mode + self.Tf_mode = cfg.future_num_frames + self.nets["policy"] = CTT( + n_embd=cfg.n_embd, + embed_funcs=embed_funcs, + enc_nblock=cfg.enc_nblock, + dec_nblock=cfg.dec_nblock, + enc_transformer_kwargs=enc_transformer_kwargs, + enc_GNN_kwargs=enc_GNN_kwargs, + dec_transformer_kwargs=dec_transformer_kwargs, + dec_GNN_kwargs=dec_GNN_kwargs, + enc_output_params=enc_output_params, + dec_output_params=dec_output_params, + hist_lane_relation=self.hist_lane_relation, + fut_lane_relation=self.fut_lane_relation, + max_joint_cardinality=self.cfg.max_joint_cardinality, + classify_a2l_4all_lanes=self.cfg.classify_a2l_4all_lanes, + edge_func=self.edge_func, + ) + + def set_eval(self): + self.nets["policy"].eval() + + def set_train(self): + self.nets["policy"].train() + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + if "parsed_batch" in inputs: + parsed_batch = inputs["parsed_batch"] + else: + if ( + isinstance(inputs["scene_batch"], dict) + and "batch" in inputs["scene_batch"] + ): + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"]["batch"].astype(torch.float) + ) + else: + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"].astype(torch.float) + ) + + if self.fp16: + for k, v in parsed_batch.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.float32: + parsed_batch[k] = v.half() + + # tic = torch_utils.tic(timer=self.hyperparams.debug.timer) + parsed_batch_torch = { + k: v.to(self.device) + for k, v in parsed_batch.items() + if isinstance(v, torch.Tensor) + } + + parsed_batch.update(parsed_batch_torch) + dt = parsed_batch["dt"][0].item() + B, N, Th = parsed_batch["agent_hist"].shape[:3] + + Tf = parsed_batch["agent_fut"].size(2) + device = parsed_batch["agent_hist"].device + agent_type = ( + F.one_hot( + parsed_batch["agent_type"].masked_fill( + parsed_batch["agent_type"] < 0, 0 + ), + len(AgentType), + ) + .float() + .to(device) + ) + agent_type.masked_fill_((parsed_batch["agent_type"] < 0).unsqueeze(-1), 0) + agent_size = parsed_batch["extent"][..., :2] + static_feature = torch.cat([agent_size, agent_type], dim=-1) + hist_mask = parsed_batch["hist_mask"] + static_feature_tiled_h = static_feature.unsqueeze(2).repeat_interleave( + Th, 2 + ) * hist_mask.unsqueeze(-1) + + hist_xy = parsed_batch["agent_hist"][..., :2] + hist_yaw = torch.arctan2( + parsed_batch["agent_hist"][..., 6:7], parsed_batch["agent_hist"][..., 7:8] + ) + hist_sc = parsed_batch["agent_hist"][..., 6:8] + # normalize hist_sc + hist_sc = GeoUtils.normalize_sc(hist_sc) + + x_hist_uni = self.unicycle_model.get_state(hist_xy, hist_yaw, hist_mask) + u_mask = hist_mask[..., :-1] * hist_mask[..., 1:] + u_hist_uni = self.unicycle_model.inverse_dyn( + x_hist_uni[..., :-1, :], x_hist_uni[..., 1:, :], mask=u_mask + ) + u_hist_uni = torch.cat([u_hist_uni[..., :1, :], u_hist_uni], dim=-2) + + hist_v = self.unicycle_model.state2vel(x_hist_uni) + hist_acce = u_hist_uni[..., :1] + + hist_xyvsc = torch.cat([hist_xy, hist_v, hist_sc], dim=-1) + hist_xysc = torch.cat([hist_xy, hist_sc], dim=-1) + hist_h = ratan2(hist_sc[..., :1], hist_sc[..., 1:2]) + hist_r = (GeoUtils.round_2pi(hist_h[..., 1:, :] - hist_h[..., :-1, :])) / dt + hist_r = torch.cat([hist_r[..., 0:1, :], hist_r], dim=-2) + hist_feature = torch.cat( + [hist_v, hist_acce, hist_r, static_feature_tiled_h], dim=-1 + ) + lane_xyh = parsed_batch["lane_xyh"] + M = lane_xyh.size(1) + lane_xysc = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + lane_adj = parsed_batch["lane_adj"].type(torch.int64) + lane_mask = parsed_batch["lane_mask"] + + raw_vars = { + TFvars.Agent_hist: hist_feature, + TFvars.Lane: lane_xysc.reshape(B, M, -1), + } + hist_aux = torch.cat([hist_xyvsc, static_feature_tiled_h], -1) + lane_aux = lane_xysc.view(B, M, -1) + + aux_xs = { + TFvars.Agent_hist: hist_aux, + TFvars.Lane: lane_aux, + } + a2l_edge = self.edge_func["a2l"]( + hist_aux.transpose(1, 2).reshape(B * Th, N, -1), + lane_aux.repeat_interleave(Th, 0), + ) + + hist_lane_agent_flag, _ = self.hist_lane_relation.categorize_lane_relation_pts( + TensorUtils.join_dimensions(hist_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + TensorUtils.join_dimensions(hist_mask, 0, 2), + lane_mask.repeat_interleave(N, 0), + ) + + hist_lane_agent_flag = hist_lane_agent_flag.view(B, N, M, Th, -1) + a2l_edge = torch.cat( + [ + a2l_edge, + hist_lane_agent_flag.permute(0, 3, 1, 2, 4).reshape(B * Th, N, M, -1), + ], + -1, + ) + l2l_edge = self.edge_func["l2l"](lane_aux, lane_aux) + l2l_edge = torch.cat([l2l_edge, F.one_hot(lane_adj, len(LaneSegRelation))], -1) + agent_hist_mask = parsed_batch["hist_mask"] + + # cross_masks = { + # (TFvars.Agent_hist,TFvars.Lane,(FeatureAxes.A,FeatureAxes.L)): [agent_lane_mask1,agent_lane_mask2], + # } + cross_masks = dict() + var_masks = { + TFvars.Agent_hist: agent_hist_mask.float(), + TFvars.Lane: lane_mask.float(), + } + enc_edges = { + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): a2l_edge, + (TFvars.Lane, TFvars.Lane, FeatureAxes.L): l2l_edge, + } + + frame_indices = { + TFvars.Agent_hist: torch.arange(Th, device=device)[None, None, :].repeat( + B, N, 1 + ), + } + + if ( + "agent_fut" in parsed_batch + and "fut_mask" in parsed_batch + and parsed_batch["fut_mask"].any() + and run_mode in [RunMode.TRAIN, RunMode.VALIDATE] + ): + fut_xy = parsed_batch["agent_fut"][..., :2] + fut_sc = parsed_batch["agent_fut"][..., 6:8] + fut_sc = GeoUtils.normalize_sc(fut_sc) + fut_xysc = torch.cat([fut_xy, fut_sc], -1) + fut_mask = parsed_batch["fut_mask"] + + mode_valid_flag = fut_mask.all(-1) + end_points = fut_xysc[:, :, -1] # Only look at final time for GT! + + GT_lane_mode, _ = self.fut_lane_relation.categorize_lane_relation_pts( + end_points.reshape(B * N, 1, 4), + lane_xysc.repeat_interleave(N, 0), + fut_mask.any(-1).reshape(B * N, 1), + lane_mask.repeat_interleave(N, 0), + force_select=not self.null_lane_mode, + ) + # You could have two lanes that it is both on + + GT_lane_mode = GT_lane_mode.squeeze(-2).argmax(-1).reshape(B, N, M) + if self.null_lane_mode: + GT_lane_mode = torch.cat( + [GT_lane_mode, (GT_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, GT_homotopy = identify_pairwise_homotopy(fut_xy, mask=fut_mask) + GT_homotopy = GT_homotopy.type(torch.int64).reshape(B, N, N) + else: + GT_lane_mode = None + GT_homotopy = None + center_from_agents = parsed_batch["center_from_agents"] + if run_mode == RunMode.INFER: + num_samples = kwargs.get("num_samples", 10) + else: + num_samples = None + vars, mode_pred = self.nets["policy"]( + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type=agent_type, + enc_edges=enc_edges, + GT_lane_mode=GT_lane_mode, + GT_homotopy=GT_homotopy, + center_from_agents=center_from_agents, + num_samples=num_samples, + ) + output_keys = [ + "trajectories", + "inputs", + "states", + "input_violation", + "jerk", + "type_mask", + ] + output = {k: v for k, v in vars.items() if k in output_keys} + if run_mode == RunMode.INFER: + B, num_samples, N, Tf = output["trajectories"].shape[:-1] + log_pis = ( + mode_pred["joint_pred"] + .view(1, B, 1, num_samples) + .expand(N, B, Tf, num_samples) + ) + if "states" in output: + state_xyhv = torch.cat( + [ + output["states"][..., :2], + output["states"][..., 3:4], + output["states"][..., 2:3], + ], + -1, + ) + else: + # calculate velocity from trajectories + pred_vel = self.unicycle_model.calculate_vel( + output["trajectories"][..., :2], + output["trajectories"][..., 2:], + hist_mask.any(-1)[:, None, :, None] + .repeat_interleave(num_samples, 1) + .repeat_interleave(Tf, -1), + ) + state_xyhv = torch.cat([output["trajectories"], pred_vel], -1) + mus_xyhv = state_xyhv.permute(2, 0, 3, 1, 4) + log_sigmas = torch.zeros_like(mus_xyhv) + corrs = torch.zeros_like(mus_xyhv[..., 0]) + pred_dist = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) + output["pred_dist"] = pred_dist + output["pred_ml"] = state_xyhv + # output["pred_single"] = torch.full([B,0,Tf,4],np.nan,device=device) + output["sample_modes"] = dict( + lane_mode=mode_pred["lane_mode_sample"], + homotopy=mode_pred["homotopy_sample"], + ) + # print(output["inputs"][...,1].abs().max()) + else: + if "trajectories" in output: + # pad trajectories to Tf in case of anomaly + if output["trajectories"].shape[-2] != Tf: + breakpoint() + dynamic_outputs = { + k: v + for k, v in output.items() + if k + in [ + "inputs", + "states", + "input_violation", + "jerk", + "trajectories", + ] + } + if output["trajectories"].shape[-2] < Tf: + func = ( + lambda x: x + if x.shape[-2] == Tf + else torch.cat( + [ + x, + x[..., -1:, :].repeat_interleave( + Tf - x.shape[-2], -2 + ), + ], + -2, + ) + ) + elif output["trajectories"].shape[-2] > Tf: + func = lambda x: x[..., :Tf, :] + dynamic_outputs = TensorUtils.recursive_dict_list_tuple_apply( + dynamic_outputs, + { + torch.Tensor: func, + type(None): lambda x: x, + }, + ) + output.update(dynamic_outputs) + dec_xysc = torch.cat( + [ + output["trajectories"][..., :2], + torch.sin(output["trajectories"][..., 2:3]), + torch.cos(output["trajectories"][..., 2:3]), + ], + -1, + ) + if dec_xysc.shape[-2] != Tf: + pass + else: + if dec_xysc.ndim == 4: + DS = 1 + dec_xysc = dec_xysc.unsqueeze(1) + elif dec_xysc.ndim == 5: + DS = dec_xysc.size(1) # decode sample + + ( + dec_lm, + dec_lm_margin, + ) = self.fut_lane_relation.categorize_lane_relation_pts( + dec_xysc[:, :, :, -1].view(B * DS * N, 1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + fut_mask.repeat_interleave(DS, 0) + .any(-1) + .reshape(B * DS * N, 1), + lane_mask.repeat_interleave(N * DS, 0), + force_select=False, + force_unique=False, + const_override=dict(Y_near_thresh=1.0, X_rear_thresh=5.0), + ) + + y_dev, psi_dev = LaneUtils.get_ypsi_dev( + dec_xysc.view(B * DS * N, -1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + ) + + dec_angle_diff, dec_homotopy = identify_pairwise_homotopy( + dec_xysc[..., :2].reshape(B * DS, N, Tf, -1), + mask=fut_mask.any(-1).repeat_interleave(DS, 0), + ) + dec_homotopy_margin = torch.stack( + [ + HOMOTOPY_THRESHOLD - dec_angle_diff.abs(), + -dec_angle_diff - HOMOTOPY_THRESHOLD, + dec_angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + output["dec_lm_margin"] = dec_lm_margin.view(B, DS, N, M, -1) + output["dec_lane_y_dev"] = y_dev.view(B, DS, N, M, Tf, -1) + output["dec_lane_psi_dev"] = psi_dev.view(B, DS, N, M, Tf, -1) + output["dec_homotopy_margin"] = dec_homotopy_margin.view( + B, DS, N, N, -1 + ) + output["dec_lm"] = dec_lm.view(B, DS, N, M, -1) + output["dec_homotopy"] = dec_homotopy.view(B, DS, N, N, -1) + + output.update(mode_pred) + output["GT_lane_mode"] = GT_lane_mode + output["GT_homotopy"] = GT_homotopy + output["mode_valid_flag"] = mode_valid_flag + output["batch"] = parsed_batch + output["step_time"] = self.cfg.step_time + return output + + def compute_metrics(self, pred_batch, data_batch): + EPS = 1e-3 + metrics = dict() + if "GT_homotopy" in pred_batch: + # train/validation mode + batch = pred_batch["batch"] + agent_fut, fut_mask, pred_traj = TensorUtils.to_numpy( + (batch["agent_fut"], batch["fut_mask"], pred_batch["trajectories"]) + ) + if pred_traj.shape[-2] != agent_fut.shape[-2]: + return metrics + mode_valid_flag = pred_batch["mode_valid_flag"] + a2a_valid_flag = mode_valid_flag.unsqueeze(-1) * mode_valid_flag.unsqueeze( + -2 + ) + + metrics["joint_mode_accuracy"] = TensorUtils.to_numpy( + torch.softmax(pred_batch["joint_pred"], dim=1)[:, 0].mean() + ).item() + metrics["joint_mode_correct_rate"] = TensorUtils.to_numpy( + ( + pred_batch["joint_pred"][:, 0] + == pred_batch["joint_pred"].max(dim=1)[0] + ) + .float() + .mean() + ).item() + + lane_mode_correct_flag = pred_batch["GT_lane_mode"].argmax(-1) == ( + pred_batch["lane_mode_pred"].argmax(-1).squeeze(-1) + ) + metrics["pred_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (lane_mode_correct_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + metrics["lane_mode_accuracy"] = TensorUtils.to_numpy( + ( + ( + torch.softmax(pred_batch["lane_mode_pred"], dim=-1).squeeze(-2) + * pred_batch["GT_lane_mode"] + ).sum(-1) + * mode_valid_flag + ).sum() + / mode_valid_flag.sum() + ).item() + + homotopy_correct_flag = pred_batch["GT_homotopy"] == ( + pred_batch["homotopy_pred"].argmax(-1) + ) + + metrics["pred_homotopy_correct_rate"] = TensorUtils.to_numpy( + (homotopy_correct_flag * a2a_valid_flag).sum() / a2a_valid_flag.sum() + ).item() + metrics["homotopy_accuracy"] = TensorUtils.to_numpy( + ( + ( + torch.softmax(pred_batch["homotopy_pred"], dim=-1) + * F.one_hot(pred_batch["GT_homotopy"], len(HomotopyType)) + ).sum(-1) + * a2a_valid_flag + ).sum() + / a2a_valid_flag.sum() + ).item() + + B, N, Tf = agent_fut.shape[:3] + extent = batch["extent"] + traj = pred_batch["trajectories"] + DS = traj.size(1) + GT_homotopy = pred_batch["GT_homotopy"] + GT_lane_mode = pred_batch["GT_lane_mode"] + pred_xysc = torch.cat( + [traj[..., :2], torch.sin(traj[..., 2:3]), torch.cos(traj[..., 2:3])], + -1, + ).view(B, -1, N, Tf, 4) + end_points = pred_xysc[:, :, :, -1] # Only look at final time + lane_xyh = batch["lane_xyh"] + M = lane_xyh.size(1) + lane_xysc = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + pred_lane_mode, _ = self.fut_lane_relation.categorize_lane_relation_pts( + end_points.reshape(B * N * DS, 1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + batch["fut_mask"] + .any(-1) + .repeat_interleave(DS, 0) + .reshape(B * DS * N, 1), + batch["lane_mask"].repeat_interleave(DS * N, 0), + force_select=False, + ) + # You could have two lanes that it is both on + + pred_lane_mode = pred_lane_mode.squeeze(-2).argmax(-1).reshape(B, DS, N, M) + pred_lane_mode = torch.cat( + [pred_lane_mode, (pred_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, pred_homotopy = identify_pairwise_homotopy( + pred_xysc[..., :2].view(B * DS, N, Tf, 2), + mask=batch["fut_mask"].repeat_interleave(DS, 0), + ) + pred_homotopy = pred_homotopy.type(torch.int64).reshape(B, DS, N, N) + ML_homotopy_flag = (pred_homotopy[:, 1] == GT_homotopy).all(-1) + + ML_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + metrics["ML_homotopy_correct_rate"] = TensorUtils.to_numpy( + (ML_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_homotopy_flag = ( + (pred_homotopy[:, 1:] == GT_homotopy[:, None]).any(1).all(-1) + ) + all_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + metrics["all_homotopy_correct_rate"] = TensorUtils.to_numpy( + (all_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_lane_mode_flag = (pred_lane_mode[:, 1] == GT_lane_mode).all(-1) + all_lane_mode_flag = ( + (pred_lane_mode[:, 1:] == GT_lane_mode[:, None]).any(1).all(-1) + ) + metrics["ML_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + metrics["all_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (all_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_scene_mode_flag = (pred_homotopy[:, 1] == GT_homotopy).all(-1) & ( + pred_lane_mode[:, 1] == GT_lane_mode + ).all(-1) + metrics["ML_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_scene_mode_flag = ( + (pred_homotopy[:, 1:] == GT_homotopy[:, None]).all(-1) + & (pred_lane_mode[:, 1:] == GT_lane_mode[:, None]).all(-1) + ).any(1) + metrics["all_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (all_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + pred_edges = self.bu.generate_edges( + batch["agent_type"].repeat_interleave(DS, 0), + extent.repeat_interleave(DS, 0), + TensorUtils.join_dimensions(traj[..., :2], 0, 2), + TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), + batch_first=True, + ) + pred_edges = {k: v for k, v in pred_edges.items() if k != "PP"} + + coll_loss, dis = collision_loss(pred_edges=pred_edges, return_dis=True) + dis_padded = { + k: v.nan_to_num(1.0).view(B, DS, -1, Tf) for k, v in dis.items() + } + edge_mask = { + k: torch.logical_not(v.isnan().any(-1)).view(B, DS, -1) + for k, v in dis.items() + } + for k in dis_padded: + metrics["collision_rate_ML_" + k] = TensorUtils.to_numpy( + (dis_padded[k][:, 1] < 0).sum() + / (edge_mask[k][:, 1].sum() + EPS) + / Tf + ).item() + metrics[f"collision_rate_all_{DS}_mode_{k}"] = TensorUtils.to_numpy( + (dis_padded[k] < 0).sum() / (edge_mask[k].sum() + EPS) / Tf + ).item() + metrics["collision_rate_ML"] = TensorUtils.to_numpy( + sum([(v[:, 1] < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k][:, 1].sum() for k in edge_mask]) + EPS) + ).item() + metrics[f"collision_rate_all_{DS}_mode"] = TensorUtils.to_numpy( + sum([(v < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k].sum() for k in edge_mask]) + EPS) + ).item() + + confidence = np.ones([B * N, 1]) + Nmode = pred_traj.shape[1] + agent_type = TensorUtils.to_numpy(batch["agent_type"]) + vehicle_mask = ( + (agent_type == AgentType.VEHICLE) + | (agent_type == AgentType.BICYCLE) + | (agent_type == AgentType.MOTORCYCLE) + ) + pedestrian_mask = agent_type == AgentType.PEDESTRIAN + agent_mask = fut_mask.any(-1) + dt = self.cfg.step_time + for Tsecond in [3.0, 4.0, 5.0, 6.0, 8.0]: + if Tf < Tsecond / dt: + continue + Tf_bar = int(Tsecond / dt) + # oracle mode means the trajectory decoded under the GT scene mode + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + metrics[f"oracle_ADE@{Tsecond}s"] = allADE + metrics[f"oracle_FDE@{Tsecond}s"] = allFDE + metrics[f"oracle_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"oracle_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"oracle_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"oracle_pedestrian_FDE@{Tsecond}s"] = pedFDE + + # the second mode is the most likely mode (first is GT) + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + + metrics[f"ML_ADE@{Tsecond}s"] = allADE + metrics[f"ML_FDE@{Tsecond}s"] = allFDE + metrics[f"ML_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"ML_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"ML_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"ML_pedestrian_FDE@{Tsecond}s"] = pedFDE + + # minADE and minFDE are calculated excluding the oracle mode + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1:, ..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode - 1, 1) / (Nmode - 1), + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1:, ..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode - 1, 1) / (Nmode - 1), + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + + metrics[f"min_ADE@{Tsecond}s"] = allADE + metrics[f"min_FDE@{Tsecond}s"] = allFDE + metrics[f"min_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"min_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"min_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"min_pedestrian_FDE@{Tsecond}s"] = pedFDE + + if "dec_lm" in pred_batch: + DS = pred_batch["dec_lm"].shape[1] + M = batch["lane_xyh"].shape[-2] + if self.cfg.classify_a2l_4all_lanes: + dec_lm, dec_cond_lm, lane_mask = TensorUtils.to_numpy( + ( + pred_batch["dec_lm"], + F.one_hot( + pred_batch["dec_cond_lm"][..., :M], + len(self.fut_lane_relation), + ), + batch["lane_mask"], + ) + ) + mask = ( + (agent_mask[:, :, None] * lane_mask[:, None]) + .reshape(B, 1, N, M) + .repeat(DS, 1) + ) + lm_consistent_rate = ((dec_lm == dec_cond_lm).all(-1) * mask).sum() / ( + (mask).sum() + EPS + ) + else: + dec_lm, dec_cond_lm, lane_mask = TensorUtils.to_numpy( + ( + pred_batch["dec_lm"], + pred_batch["dec_cond_lm"], + batch["lane_mask"], + ) + ) + dec_lm = dec_lm.argmax(-1) + + if self.null_lane_mode: + # augment the null lane + dec_lm = np.concatenate( + [dec_lm, (dec_lm == 0).all(-1, keepdims=True)], -1 + ) + lane_mask = np.concatenate( + [lane_mask, np.ones([B, 1], dtype=bool)], -1 + ) + mask = ( + (agent_mask[:, :, None] * lane_mask[:, None]) + .reshape(B, 1, N, -1) + .repeat(DS, 1) + ) + lm_consistent_rate = ((dec_lm == dec_cond_lm) * mask).sum() / ( + (mask).sum() + EPS + ) + + metrics["lm_consistent_rate"] = lm_consistent_rate + if "dec_homotopy" in pred_batch: + dec_homotopy, dec_cond_homotopy = TensorUtils.to_numpy( + (pred_batch["dec_homotopy"], pred_batch["dec_cond_homotopy"]) + ) + DS = dec_homotopy.shape[1] + mask = ( + (agent_mask[:, :, None] * agent_mask[:, None]) + .reshape(B, 1, N, N) + .repeat(DS, 1) + ) + + homo_consistent_rate = ( + (dec_homotopy.squeeze(-1) == dec_cond_homotopy) * mask + ).sum() / ((mask).sum() + EPS) + metrics["homo_consistent_rate"] = homo_consistent_rate + + return metrics + + def compute_losses(self, pred_batch, inputs): + EPS = 1e-3 + batch = pred_batch["batch"] + lm_K = len(self.fut_lane_relation) + homo_K = len(HomotopyType) + GT_homotopy = F.one_hot(pred_batch["GT_homotopy"], homo_K).float() + GT_lane_mode = F.one_hot(pred_batch["GT_lane_mode"], lm_K).float() + mode_valid_flag = pred_batch["mode_valid_flag"] + dt = self.cfg.step_time + Tf = batch["agent_fut"].shape[-2] + if not self.cfg.classify_a2l_4all_lanes: + # We perform softmax (in cross entropy) over the lane segments + mask_noton_mode = torch.ones( + len(self.fut_lane_relation), dtype=torch.bool, device=self.device + ) + mask_noton_mode[self.fut_lane_relation.NOTON] = False + DS = pred_batch["trajectories"].size(1) + future_mask = batch["fut_mask"] + agent_mask = batch["hist_mask"].any(-1) + # mode probability loss + + B, N, Mext = pred_batch["GT_lane_mode"].shape[:3] + if self.null_lane_mode: + M = Mext - 1 + else: + M = Mext + # marginal mode prediction probability + lane_mode_pred = pred_batch["lane_mode_pred"] + homotopy_pred = pred_batch["homotopy_pred"] + # marginal probability loss, use "subtract max" trick for stable cross entropy + + B, N = agent_mask.shape + if not self.cfg.classify_a2l_4all_lanes: + # We perform softmax (in cross entropy) over the lane segments + GT_lane_mode_marginal = torch.masked_select( + GT_lane_mode, mask_noton_mode[None, None, None, :] + ).view(*GT_lane_mode.shape[:3], -1) + GT_lane_mode_marginal = GT_lane_mode_marginal.swapaxes( + -1, -2 + ) # Not onehot anymore + + lane_mode_pred = lane_mode_pred.reshape(-1, Mext) + marginal_lm_loss = self.lane_mode_loss( + lane_mode_pred - lane_mode_pred.max(dim=-1, keepdim=True)[0], + GT_lane_mode_marginal.reshape(-1, Mext), + ) + lm_mask = mode_valid_flag.flatten().repeat_interleave(lm_K - 1, 0) + else: + lane_mode_pred = lane_mode_pred.reshape(-1, lm_K) + marginal_lm_loss = self.lane_mode_loss( + lane_mode_pred - lane_mode_pred.max(dim=-1, keepdim=True)[0], + GT_lane_mode.reshape(-1, lm_K), + ) + lm_mask = mode_valid_flag.flatten().repeat_interleave(M, 0) + raise NotImplementedError + + marginal_homo_loss = self.homotopy_loss( + homotopy_pred.reshape(-1, homo_K) + - homotopy_pred.reshape(-1, homo_K).max(dim=-1, keepdim=True)[0], + GT_homotopy.reshape(-1, homo_K), + ) + + homo_mask = ( + mode_valid_flag.unsqueeze(2) * mode_valid_flag.unsqueeze(1) + ).flatten() + marginal_lm_loss = (marginal_lm_loss * lm_mask).sum() / (lm_mask.sum() + EPS) + marginal_homo_loss = (marginal_homo_loss * homo_mask).sum() / ( + homo_mask.sum() + EPS + ) + + # joint probability + joint_logpi = pred_batch["joint_pred"] + GT_joint_mode = torch.zeros_like(joint_logpi) + GT_joint_mode[ + ..., 0 + ] = 1 # FIXME: Shouldn't this be only if the first element is the GT? + joint_prob_loss = self.joint_mode_loss( + (joint_logpi - joint_logpi.max(dim=-1, keepdim=True)[0]).nan_to_num(0), + GT_joint_mode, + ).mean() + + # decoded trajectory consistency loss + dec_normalized_prob = pred_batch["dec_cond_prob"] / pred_batch[ + "dec_cond_prob" + ].sum(-1, keepdim=True) + + if "dec_lm_margin" in pred_batch: + mask = (agent_mask.unsqueeze(2) * batch["lane_mask"].unsqueeze(1)).view( + B, 1, N, M, 1 + ) + if self.null_lane_mode: + mask = torch.cat([mask, torch.ones_like(mask[:, :, :, :1])], -2) + noton_mask = torch.ones(len(self.fut_lane_relation), device=mask.device) + noton_mask[self.fut_lane_relation.NOTON] = 0 + mask = mask * noton_mask + if self.null_lane_mode: + # the margin for null lane is the inverse of the maximum of all lane margins + null_lm_margin = ( + -pred_batch["dec_lm_margin"].max(-2)[0] + * (1 - noton_mask)[None, None, None, :] + + -pred_batch["dec_lm_margin"].min(-2)[0] + * noton_mask[None, None, None, :] + ) + dec_lm_margin = torch.cat( + [pred_batch["dec_lm_margin"], null_lm_margin.unsqueeze(-2)], -2 + ) + else: + dec_lm_margin = pred_batch["dec_lm_margin"] - self.cfg.lm_margin_offset + + if self.weighted_consistency_loss: + lm_consistency_loss = ( + loss_clip( + F.relu( + -dec_lm_margin + * pred_batch["dec_cond_lm"].unsqueeze(-1) + * mask + ), + max_loss=4.0, + ).sum(-1) + * dec_normalized_prob[..., None, None] + ).sum() / B # FIXME: should this now be over M? (with self.cfg.classify_a2l_4all_lanes) + else: + lm_consistency_loss = ( + ( + loss_clip( + F.relu( + -dec_lm_margin + * pred_batch["dec_cond_lm"].unsqueeze(-1) + * mask + ), + max_loss=4.0, + ).sum(-1) + ).sum() + / B + / DS + ) + + else: + lm_consistency_loss = torch.tensor(0.0, device=self.device) + if "dec_lane_y_dev" in pred_batch: + mask = (agent_mask.unsqueeze(2) * batch["lane_mask"].unsqueeze(1)).view( + B, 1, N, M + ) + time_weight = torch.arange(1, Tf + 1, device=self.device) / Tf**2 + psi_dev = pred_batch["dec_lane_psi_dev"].squeeze(-1) + y_dev = pred_batch["dec_lane_y_dev"].squeeze(-1) + yaw_dev_loss = ( + psi_dev.abs() + * mask[..., :M, None] + * pred_batch["dec_cond_lm"][..., :M, None] + * time_weight[None, None, None, None, :] + ).sum() / (mask[..., :M].sum() + EPS) + y_dev_loss = ( + y_dev.abs() + * mask[..., :M, None] + * pred_batch["dec_cond_lm"][..., :M, None] + * time_weight[None, None, None, None, :] + ).sum() / (mask[..., :M].sum() + EPS) + else: + yaw_dev_loss = torch.tensor(0.0, device=self.device) + y_dev_loss = torch.tensor(0.0, device=self.device) + if "dec_homotopy_margin" in pred_batch: + mask = (agent_mask.unsqueeze(2) * agent_mask.unsqueeze(1))[ + :, None, :, :, None + ] + if self.weighted_consistency_loss: + homotopy_consistency_loss = ( + loss_clip( + F.relu( + -pred_batch["dec_homotopy_margin"] + * F.one_hot( + pred_batch["dec_cond_homotopy"], len(HomotopyType) + ) + * mask + ), + max_loss=4.0, + ).sum(-1) + * dec_normalized_prob[..., None, None] + ).sum() / B + else: + homotopy_consistency_loss = ( + ( + loss_clip( + F.relu( + -pred_batch["dec_homotopy_margin"] + * F.one_hot( + pred_batch["dec_cond_homotopy"], + len(HomotopyType), + ) + * mask + ), + max_loss=4.0, + ).sum(-1) + ).sum() + / B + / DS + ) + else: + homotopy_consistency_loss = torch.tensor(0.0, device=self.device) + + # trajectory reconstruction loss + traj = pred_batch["trajectories"] + if traj.shape[-2] != batch["agent_fut"].shape[-2]: + xy_loss = torch.tensor(0.0, device=traj.device) + yaw_loss = torch.tensor(0.0, device=traj.device) + coll_loss = torch.tensor(0.0, device=traj.device) + acce_reg_loss = torch.tensor(0.0, device=traj.device) + steering_reg_loss = torch.tensor(0.0, device=traj.device) + input_violation_loss = torch.tensor(0.0, device=traj.device) + jerk_loss = torch.tensor(0.0, device=traj.device) + else: + # only penalize the trajectory under GT mode + traj_GT_mode = pred_batch["trajectories"][:, 0] + GT_xy = batch["agent_fut"][..., :2] + GT_h = ratan2(batch["agent_fut"][..., 6:7], batch["agent_fut"][..., 7:8]) + + xy_loss = self.traj_loss(GT_xy, traj_GT_mode[..., :2]) + xy_loss = xy_loss.norm(dim=-1) + xy_loss = (xy_loss * future_mask).sum() / (future_mask.sum() + EPS) + yaw_loss = torch.abs( + GeoUtils.round_2pi(traj_GT_mode[..., 2:3] - GT_h) + ).squeeze(-1) + yaw_loss = (yaw_loss * future_mask).sum() / (future_mask.sum() + EPS) + + # collision loss + extent = batch["extent"] + DS = traj.size(1) + pred_edges = self.bu.generate_edges( + batch["agent_type"].repeat_interleave(DS, 0), + extent.repeat_interleave(DS, 0), + TensorUtils.join_dimensions(traj[..., :2], 0, 2), + TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), + batch_first=True, + ) + + edge_weight = pred_batch["dec_cond_prob"].flatten().view(-1, 1) + + coll_loss = collision_loss( + pred_edges={k: v for k, v in pred_edges.items() if k != "PP"}, + weight=edge_weight, + ).nan_to_num(0) + # diff = torch.norm((traj[:,1,...,:2]-traj[:,2,...,:2])*agent_mask[...,None,None])+torch.norm((traj[:,2,...,:2]-traj[:,3,...,:2])*agent_mask[...,None,None])+torch.norm((traj[:,3,...,:2]-traj[:,4,...,:2])*agent_mask[...,None,None]) + # print(diff) + acce_reg_loss = torch.tensor(0.0, device=traj.device) + steering_reg_loss = torch.tensor(0.0, device=traj.device) + input_violation_loss = torch.tensor(0.0, device=traj.device) + jerk_loss = torch.tensor(0.0, device=traj.device) + + type_mask = pred_batch["type_mask"] + inputs = pred_batch["inputs"] + input_violation = ( + pred_batch["input_violation"] + if "input_violation" in pred_batch + else dict() + ) + jerk = pred_batch["jerk"] if "jerk" in pred_batch else dict() + + for k, dyn in self.nets["policy"].dyn.items(): + if type_mask[k].sum() == 0: + continue + if isinstance(dyn, Unicycle) or isinstance(dyn, Unicycle_xyvsc): + acce_reg_loss += ( + inputs[k][..., 0:1].norm(dim=-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + steering_reg_loss += ( + inputs[k][..., 1:2].norm(dim=-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + if k in input_violation: + input_violation_loss += ( + input_violation[k].sum(-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + if k in jerk: + jerk_loss += ( + (jerk[k] ** 2).sum(-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + losses = dict( + marginal_lm_loss=marginal_lm_loss, + marginal_homo_loss=marginal_homo_loss, + l2_reg=lane_mode_pred.nan_to_num(0).norm(dim=-1).mean() + + homotopy_pred.nan_to_num(0).norm(dim=-1).mean() + + joint_logpi.nan_to_num(0).abs().mean(), + joint_prob_loss=joint_prob_loss, + lm_consistency_loss=lm_consistency_loss, + homotopy_consistency_loss=homotopy_consistency_loss, + yaw_dev_loss=yaw_dev_loss, + y_dev_loss=y_dev_loss, + xy_loss=xy_loss, + yaw_loss=yaw_loss, + coll_loss=coll_loss, + acce_reg_loss=acce_reg_loss, + steering_reg_loss=steering_reg_loss, + input_violation_loss=input_violation_loss, + jerk_loss=jerk_loss, + ) + for k, v in losses.items(): + if v.isinf() or v.isnan(): + raise ValueError(f"{k} becomes NaN") + return losses + + def log_pred_image( + self, + batch, + pred_batch, + batch_idx, + logger, + log_all_image=False, + savegif=False, + **kwargs, + ): + if batch_idx == 0: + return + + try: + parsed_batch = pred_batch["batch"] + indices = ( + list(range(parsed_batch["hist_mask"].shape[0])) + if log_all_image + else [parsed_batch["hist_mask"][..., -1].sum(-1).argmax().item()] + ) + world_from_agent = TensorUtils.to_numpy(parsed_batch["world_from_agent"]) + center_xy = TensorUtils.to_numpy( + parsed_batch["centered_agent_state"][:, :2] + ) + centered_agent_state = TensorUtils.to_numpy( + parsed_batch["centered_agent_state"] + ) + world_yaw = np.arctan2( + centered_agent_state[:, -2], centered_agent_state[:, -1] + ) + curr_posyaw = torch.cat( + [ + parsed_batch["agent_hist"][..., :2], + torch.arctan2( + parsed_batch["agent_hist"][..., -2], + parsed_batch["agent_hist"][..., -1], + )[..., None], + ], + -1, + ) + NS = pred_batch["trajectories"].shape[1] + traj = torch.cat( + [ + curr_posyaw[:, None, :, -1:].repeat_interleave(NS, 1), + pred_batch["trajectories"], + ], + -2, + ) + traj = TensorUtils.to_numpy(traj) + world_traj_xy = GeoUtils.batch_nd_transform_points_np( + traj[..., :2], world_from_agent[:, None, None] + ) + world_traj_yaw = traj[..., 2] + world_yaw[:, None, None, None] + world_traj = np.concatenate([world_traj_xy, world_traj_yaw[..., None]], -1) + + extent = TensorUtils.to_numpy(parsed_batch["agent_hist_extent"][:, :, -1]) + hist_mask = TensorUtils.to_numpy(parsed_batch["hist_mask"]) + # pool = mp.Pool(len(indices)) + + # def plot_scene(bi): + for bi in indices: + if isinstance(logger, str) or isinstance(logger, Path): + # directly save file + html_file_name = ( + Path(logger) / f"CTT_visualization_{batch_idx}_{bi}.html" + ) + else: + html_file_name = f"CTT_visualization_{bi}.html" + if os.path.exists(html_file_name): + os.remove(html_file_name) + + # plot agent + for mode in range(min(NS, 3)): + output_file(html_file_name) + graph = figure(title="Bokeh graph", width=900, height=900) + graph.xgrid.grid_line_color = None + graph.ygrid.grid_line_color = None + + graph.x_range = Range1d( + center_xy[bi, 0] - 40, center_xy[bi, 0] + 40 + ) + graph.y_range = Range1d( + center_xy[bi, 1] - 30, center_xy[bi, 1] + 50 + ) + plot_scene_open_loop( + graph, + world_traj[bi, mode], + extent[bi], + parsed_batch["vector_maps"][bi], + np.eye(3), + bbox=[ + center_xy[bi, 0] - 40, + center_xy[bi, 0] + 40, + center_xy[bi, 1] - 30, + center_xy[bi, 1] + 50, + ], + mask=hist_mask[bi].any(-1), + color_scheme="palette", + ) + if isinstance(logger, WandbLogger): + save(graph) + wandb_html = wandb.Html(open(html_file_name)) + logger.experiment.log( + {f"val_pred_image_mode{mode}": wandb_html} + ) + + elif isinstance(logger, str) or isinstance(logger, Path): + html_file_name = ( + Path(logger) + / f"CTT_visualization_{batch_idx}_{bi}_mode{mode}.html" + ) + output_file(html_file_name) + save(graph) + png_name = ( + str(html_file_name).removesuffix(".html") + + f"_mode{mode}.png" + ) + export_png(graph, filename=png_name) + del graph + if savegif: + graph = figure(title="Bokeh graph", width=900, height=900) + graph.xgrid.grid_line_color = None + graph.ygrid.grid_line_color = None + + graph.x_range = Range1d( + center_xy[bi, 0] - 40, center_xy[bi, 0] + 40 + ) + graph.y_range = Range1d( + center_xy[bi, 1] - 30, center_xy[bi, 1] + 50 + ) + gif_name = ( + str(html_file_name).removesuffix(".html") + + f"_mode{mode}.gif" + ) + animate_scene_open_loop( + graph, + world_traj[bi, mode], + extent[bi], + parsed_batch["vector_maps"][bi], + np.eye(3), + bbox=[ + center_xy[bi, 0] - 40, + center_xy[bi, 0] + 40, + center_xy[bi, 1] - 30, + center_xy[bi, 1] + 50, + ], + mask=hist_mask[bi].any(-1), + color_scheme="palette", + dt=parsed_batch["dt"][0].item(), + gif_name=gif_name, + # tmp_dir=str(html_file_name).removesuffix(".html"), + ) + del graph + + # pool.map(plot_scene, indices) + except Exception as e: + print(e) + pass diff --git a/diffstack/modules/predictors/__init__.py b/diffstack/modules/predictors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/modules/predictors/factory.py b/diffstack/modules/predictors/factory.py new file mode 100644 index 0000000..d3e87db --- /dev/null +++ b/diffstack/modules/predictors/factory.py @@ -0,0 +1,69 @@ +"""Factory methods for creating planner""" +from diffstack.configs.base import AlgoConfig + +from diffstack.utils.utils import removeprefix + +from diffstack.modules.predictors.kinematic_predictor import KinematicTreeModel +from diffstack.modules.predictors.tbsim_predictors import ( + AgentFormerTrafficModel, +) + +from diffstack.modules.predictors.CTT import CTTTrafficModel + + +def predictor_factory( + model_registrar, + config: AlgoConfig, + logger, + device, + input_mappings={}, + checkpoint=None, +): + """ + A factory for creating predictor modules + + Args: + config (AlgoConfig): an AlgoConfig object, + + Returns: + predictor: predictor module + """ + algo_name = config.name + + if algo_name == "kinematic": + predictor = KinematicTreeModel( + model_registrar, config, logger, device, input_mappings=input_mappings + ) + + elif algo_name in [ + "agentformer_multistage", + "agentformer_singlestage", + "agentformer", + ]: + predictor = AgentFormerTrafficModel( + model_registrar, config, logger, device, input_mappings=input_mappings + ) + elif algo_name == "CTT": + predictor = CTTTrafficModel( + model_registrar, config, logger, device, input_mappings=input_mappings + ) + else: + raise NotImplementedError(f"{algo_name} is not implemented") + + if checkpoint is not None: + if "state_dict" in checkpoint: + predictor_dict = { + removeprefix(k, "components.predictor."): v + for k, v in checkpoint["state_dict"].items() + if k.startswith("components.predictor.") + } + elif "model_dict" in checkpoint: + predictor_dict = { + "model." + k: v for k, v in checkpoint["model_dict"].items() + } + # Are we loading all model parameters? + assert all([k in predictor_dict for k in predictor.state_dict().keys()]) + else: + raise ValueError("Unknown checkpoint format") + predictor.load_state_dict(predictor_dict) + return predictor diff --git a/diffstack/modules/predictors/kinematic_predictor.py b/diffstack/modules/predictors/kinematic_predictor.py new file mode 100644 index 0000000..df772a1 --- /dev/null +++ b/diffstack/modules/predictors/kinematic_predictor.py @@ -0,0 +1,163 @@ +import torch +import numpy as np +from diffstack.dynamics.unicycle import Unicycle +from diffstack.modules.module import Module, DataFormat +from diffstack.utils.batch_utils import batch_utils +from trajdata.data_structures.batch import AgentBatch,SceneBatch +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +import diffstack.utils.tensor_utils as TensorUtils +from typing import Dict, Optional, Union, Any + + +class KinematicTreeModel(Module): + + @property + def input_format(self) -> DataFormat: + if self.scene_centric: + return DataFormat(["scene_batch"]) + else: + return DataFormat(["agent_batch"]) + + @property + def output_format(self) -> DataFormat: + return DataFormat(["pred_dist", "pred_ml", "pred_single", "metrics"]) + + def __init__(self,model_registrar, config, log_writer, device, input_mappings={}): + super().__init__(model_registrar, config, log_writer, device, input_mappings) + self.config = config + self.step_time = config.step_time + if "dynamics" in config: + if config.dynamics.type=="Unicycle": + self.dyn = Unicycle(self.step_time,max_steer=config.dynamics.max_steer,max_yawvel=config.dynamics.max_yawvel,acce_bound=config.dynamics.acce_bound) + else: + raise NotImplementedError + else: + self.dyn=Unicycle(self.step_time) + + self.stage=config["stage"] + self.num_frames_per_stage = config["num_frames_per_stage"] + self.M = config["M"] + self.scene_centric = config["scene_centric"] + self.only_branch_closest = config.only_branch_closest if "only_branch_closest" in config else self.scene_centric + self.bu = batch_utils(rasterize_mode="none",parse=True) + + def train_step(self, inputs: Dict) -> Dict: + outputs = self.infer_step(inputs) + outputs["loss"] = torch.zeros((), device=self.device) + return outputs + + def validate_step(self, inputs: Dict) -> Dict: + outputs = self.infer_step(inputs) + outputs["metrics"] = {} + return outputs + + def infer_step(self, inputs: Dict): + if self.scene_centric: + batch = inputs["scene_batch"] + else: + batch = inputs["agent_batch"] + if isinstance(batch, dict) and "batch" in batch: + batch = batch ["batch"] + batch = self.bu.parse_batch(batch) + pred_ml, pred_dist = self.kinematic_predictions(batch) + + # Dummy single agent prediction. + agent_fut = batch["agent_fut"] + pred_single = torch.full([agent_fut.shape[0], 0, agent_fut.shape[2], 4], dtype=agent_fut.dtype, device=agent_fut.device, fill_value=torch.nan) + + output = dict(pred_dist=pred_dist, pred_ml=pred_ml, pred_single=pred_single) + if self.scene_centric: + output["scene_batch"]=batch + else: + output["agent_batch"]=batch + return output + + def kinematic_predictions(self,batch): + if batch["agent_hist"].shape[-1] == 7: # x, y, vx, vy, ax, ay, heading + yaw = batch["agent_hist"][...,6:7] + else: + assert batch["agent_hist"].shape[-1] == 8 # x, y, vx, vy, ax, ay, sin(heading), cos(heading) + yaw = torch.atan2(batch["agent_hist"][..., [-2]], batch["agent_hist"][..., [-1]]) + pos = batch["agent_hist"][...,:2] + speed = torch.norm(batch["agent_hist"][..., 2:4], dim=-1,keepdim=True) + curr_states = torch.cat([pos, yaw, speed], -1)[...,-1,:] + + bs = curr_states.shape[0] + + device = curr_states.device + assert self.M<=4 + if self.scene_centric: + Na = batch["agent_hist"].shape[1] + inputs = torch.zeros([bs,self.M,Na,2]).to(device) + if self.only_branch_closest: + if batch["agent_hist"].shape[1] == 1: + branch_idx = 0 + else: + dis = torch.norm(batch["agent_hist"][0,1:,-1,:2]-batch["agent_hist"][0,:1,-1,:2],dim=-1) + branch_idx = dis.argmin()+1 + else: + branch_idx = 1 if Na>1 else 0 + if self.M>1: + inputs[:,1,branch_idx,1]=-0.02*curr_states[:,branch_idx,2] + if self.M>2: + inputs[:,2,branch_idx,1]=0.02*curr_states[:,branch_idx,2] + if self.M>3: + inputs[:,3,branch_idx,0]=torch.clip(-curr_states[:,branch_idx,2]*0.5,min=-3) + + else: + inputs = torch.zeros([bs,self.M,2]).to(device) + if self.M>1: + inputs[:,1,1]=-0.02*curr_states[:,2] + if self.M>2: + inputs[:,2,1]=0.02*curr_states[:,2] + if self.M>3: + inputs[:,3,0]=torch.clip(-curr_states[:,2]*0.5,min=-3) + assert self.M<=4 + + + inputs = inputs.unsqueeze(-2).repeat_interleave(self.num_frames_per_stage,-2) + + T = self.num_frames_per_stage*self.stage + pred_by_stage = list() + for stage in range(self.stage): + inputs_i = inputs.repeat_interleave(self.M**stage,0) + if stage==0: + curr_state_i = curr_states + else: + curr_state_i = pred_by_stage[stage-1][...,-1,:] + pred_i = self.dyn.forward_dynamics(curr_state_i.repeat_interleave(self.M,0),TensorUtils.join_dimensions(inputs_i,0,2)) + pred_by_stage.append(pred_i) + + for stage in range(self.stage): + if self.scene_centric: + pred_by_stage[stage]=pred_by_stage[stage].reshape([bs,-1,Na,self.num_frames_per_stage,4]).repeat_interleave(self.M**(self.stage-1-stage),1) + else: + pred_by_stage[stage]=pred_by_stage[stage].reshape([bs,-1,self.num_frames_per_stage,4]).repeat_interleave(self.M**(self.stage-1-stage),1) + pred_traj = torch.cat(pred_by_stage,-2) # b x N x T x D + if self.scene_centric: + pred_traj = pred_traj.reshape(bs,self.M**self.stage,Na,T,-1) + prob = torch.ones([bs,self.M**self.stage]).to(device) # b x N + prob = prob/prob.sum(-1,keepdim=True) + if self.scene_centric: + mus = pred_traj[...,[0,1,2,3]].permute(2,0,3,1,4) # (Na, b, T, M**stage, 4) + Tf = pred_traj.shape[-2] + log_pis = torch.log(prob)[None,:,None,:].repeat_interleave(Tf,-2).repeat_interleave(Na,0) + log_sigmas = torch.log(torch.tensor((self.config['dt']*np.arange(Tf+1))[1:]**2*2, dtype=curr_states.dtype, device=curr_states.device)) + log_sigmas = log_sigmas.reshape(1, 1, Tf, 1, 1).repeat((Na, bs, 1, self.M**self.stage, 2)) + corrs = 0. * torch.ones((1, bs, Tf, self.M**self.stage), dtype=curr_states.dtype, device=curr_states.device) # TODO not sure what is reasonable + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + else: + mus = pred_traj[...,[0,1,2,3]].transpose(1,2).unsqueeze(0) # (1, b, T, M**stage, 4) + Tf = pred_traj.shape[-2] + log_pis = torch.log(prob)[None,:,None,:].repeat_interleave(Tf,-2) + log_sigmas = torch.log(torch.tensor((self.config['dt']*np.arange(Tf+1))[1:]**2*2, dtype=curr_states.dtype, device=curr_states.device)) + log_sigmas = log_sigmas.reshape(1, 1, Tf, 1, 1).repeat((1, bs, 1, self.M**self.stage, 2)) + corrs = 0. * torch.ones((1, bs, Tf, self.M**self.stage), dtype=curr_states.dtype, device=curr_states.device) # TODO not sure what is reasonable + + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + return pred_traj, y_dists + + def validation_metrics(self, pred_dist, pred_ml, agent_fut): + # Compute default metrics + metrics = compute_prediction_metrics(pred_ml, agent_fut[..., :2], y_dists=pred_dist) + return metrics \ No newline at end of file diff --git a/diffstack/modules/predictors/tbsim_predictors.py b/diffstack/modules/predictors/tbsim_predictors.py new file mode 100644 index 0000000..21a6c4e --- /dev/null +++ b/diffstack/modules/predictors/tbsim_predictors.py @@ -0,0 +1,477 @@ +import torch +import torch.nn as nn +import numpy as np +from diffstack.modules.module import Module, DataFormat, RunMode + +from diffstack.utils.utils import traj_xyh_to_xyhv, removeprefix +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.batch_utils import batch_utils +from diffstack.models.agentformer import AgentFormer +from trajdata.data_structures.batch import SceneBatch +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +import diffstack.utils.tensor_utils as TensorUtils +from typing import Dict, Any + + +from diffstack.utils.loss_utils import ( + collision_loss, +) + + +from trajdata.data_structures import StateTensor, AgentType + + +from diffstack.utils.lane_utils import SimpleLaneRelation +from diffstack.utils.homotopy import ( + identify_pairwise_homotopy, +) +import diffstack.utils.metrics as Metrics + + +class AgentFormerTrafficModel(Module): + @property + def input_format(self) -> DataFormat: + return DataFormat(["scene_batch"]) + + @property + def output_format(self) -> DataFormat: + return DataFormat( + [ + "mixed_pred_ml:validate", + "mixed_pred_ml:infer", + "metrics:train", + "metrics:validate", + "step_time", + ] + ) + + @property + def checkpoint_monitor_keys(self): + return {"valLoss": "val/losses_predictor_prediction_loss"} + + def __init__(self, model_registrar, cfg, log_writer, device, input_mappings={}): + super(AgentFormerTrafficModel, self).__init__( + model_registrar, cfg, log_writer, device, input_mappings=input_mappings + ) + + self.bu = batch_utils(rasterize_mode="none") + self.modality_shapes = self.bu.get_modality_shapes(cfg) + # assert modality_shapes["image"][0] == 15 + self.nets = nn.ModuleDict() + + self.bu = batch_utils(parse=True, rasterize_mode="none") + self.nets["policy"] = AgentFormer(cfg) + self.cfg = cfg + + if "checkpoint" in cfg and cfg.checkpoint["enabled"]: + checkpoint = torch.load(cfg.checkpoint["path"]) + predictor_dict = { + removeprefix(k, "components.predictor."): v + for k, v in checkpoint["state_dict"].items() + if k.startswith("components.predictor.") + } + self.load_state_dict(predictor_dict) + + self.device = device + self.nets["policy"] = self.nets["policy"].to(self.device) + + def set_eval(self): + self.nets["policy"].eval() + + def set_train(self): + self.nets["policy"].train() + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + if isinstance(inputs["scene_batch"], dict) and "batch" in inputs["scene_batch"]: + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"]["batch"].astype(torch.float) + ) + else: + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"].astype(torch.float) + ) + + # tic = torch_utils.tic(timer=self.hyperparams.debug.timer) + parsed_batch_torch = { + k: v.to(self.device) + for k, v in parsed_batch.items() + if isinstance(v, torch.Tensor) + } + parsed_batch.update(parsed_batch_torch) + output = self.nets["policy"]( + parsed_batch, predict=(run_mode == RunMode.INFER), **kwargs + ) + # torch_utils.toc(tic, name="prediction model", timer=self.hyperparams.debug.timer) + + # Convert to standard prediction output + trajs_xyh = output["trajectories"] # b, Ne, mode, N_agent, t, xyh + trajs_xyh = trajs_xyh[:, 0] # b, K, Na, T, 3 + trajs_xyh = trajs_xyh.transpose(1, 2) # b, Na, K, T, 3 + trajs_xyh = trajs_xyh[:, :, None] # # b, Na, S=1, K, T, 3 + # Infer velocity from xy + dt = self.hyperparams["step_time"] # hyperparams is AgentFormerConfig + trajs_xyhv = traj_xyh_to_xyhv(trajs_xyh, dt) + trajs_xyhv = StateTensor.from_array(trajs_xyhv, format="x,y,h,v_lon") + + log_probs = torch.log(output["p"]) # b, Ne, mode + log_probs = log_probs[:, :1, None, :, None] # b, 1, S, K, T=1 + + mus = output["trajectories"][:, 0, ..., :2].permute( + 2, 0, 3, 1, 4 + ) # (Na,b,T,M,2) + Na, bs, Tf, M = mus.shape[:4] + log_pis = ( + torch.log(output["p"][None, :, None, 0]) + .repeat_interleave(Na, 0) + .repeat_interleave(Tf, 2) + ) + log_sigmas = torch.zeros_like(mus) + corrs = torch.zeros_like(log_pis) + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + if "state_trajectory" in output and output["state_trajectory"] is not None: + state_traj = output["state_trajectory"][None, :, 0] + # changing state order + state_traj = torch.cat( + [state_traj[..., :2], state_traj[..., 3:], state_traj[..., 2:3]], -1 + ) + output["pred_ml"] = state_traj + else: + output["pred_ml"] = output["trajectories"][None, :, 0] + output["pred_dist"] = y_dists + + output.update(dict(data_batch=parsed_batch)) + if run_mode == RunMode.INFER: + # Convert to standardized prediction output + dt = self.hyperparams["step_time"] # hyperparams is AgentFormerConfig + mus_xyh = output["trajectories"] # b, Ne, mode, N_agent, t, xyh + log_pis = torch.log(output["p"]) # b, Ne, mode + + # pred_dist: GMM + mus_xyh = mus_xyh[:, 0] # b, mode, N_agent, t, xyh + mus_xyh = mus_xyh.permute(2, 0, 1, 3, 4) # (N_agent, b, mode, T, xyh) + # Infer velocity from xy + mus_xyhv = traj_xyh_to_xyhv(mus_xyh, dt) + mus_xyhv = mus_xyhv.transpose(2, 3) # (N_agent, b, T, mode, xyhv) + + # Currently we simply treat joint distribtion as agent-wise marginals. + log_pis = ( + log_pis[:, 0] + .reshape(1, log_pis.shape[0], 1, log_pis.shape[2]) + .repeat(mus_xyhv.shape[0], 1, mus_xyhv.shape[2], 1) + ) # n, b, T, mode + log_sigmas = torch.log( + ( + torch.arange( + 1, + mus_xyhv.shape[2] + 1, + dtype=mus_xyhv.dtype, + device=mus_xyhv.device, + ) + * dt + ) + ** 2 + * 2 + ) + log_sigmas = log_sigmas.reshape(1, 1, mus_xyhv.shape[2], 1, 1).repeat( + (mus_xyhv.shape[0], mus_xyhv.shape[1], 1, mus_xyhv.shape[3], 2) + ) + corrs = 0.0 * torch.ones( + mus_xyhv.shape[:-1], dtype=mus_xyhv.dtype, device=mus_xyhv.device + ) + + pred_dist_with_ego = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) + + # drop ego + if isinstance(inputs["scene_batch"], SceneBatch): + assert (inputs["scene_batch"].extras["robot_ind"] <= 0).all() + pred_dist = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) + + ml_mode_ind = torch.argmax(log_pis, dim=-1) # n, b, T + # pred_ml = batch_select(mus_xyhv, ml_mode_ind, 3) # n, b, T, 4 + pred_ml = mus_xyhv.permute(1, 3, 0, 2, 4) + + # Dummy single agent prediction. + if isinstance(inputs["scene_batch"], SceneBatch): + agent_fut = inputs["scene_batch"].agent_fut + else: + agent_fut = inputs["scene_batch"]["agent_fut"] + pred_single = torch.full( + [agent_fut.shape[0], 0, agent_fut.shape[2], 4], + dtype=agent_fut.dtype, + device=agent_fut.device, + fill_value=torch.nan, + ) + + output["pred_dist"] = pred_dist + output["pred_dist_with_ego"] = pred_dist_with_ego + output["pred_ml"] = pred_ml + output["pred_single"] = pred_single + output["metrics"] = {} + else: + output["pred_dist"] = None + output["pred_dist_with_ego"] = None + output["pred_ml"] = None + output["pred_single"] = None + output["metrics"] = {} + output["step_time"] = self.cfg["step_time"] + return output + + def compute_losses(self, pred_batch, inputs): + return self.nets["policy"].compute_losses(pred_batch, None) + + def compute_metrics(self, pred_batch, data_batch): + EPS = 1e-3 + metrics = dict() + # calculate GT lane mode and homotopy + batch = pred_batch["data_batch"] + fut_mask = batch["fut_mask"] + mode_valid_flag = fut_mask.all(-1) + B, N, Tf = batch["agent_fut"].shape[:3] + traj = pred_batch["trajectories"].view(B, -1, N, Tf, 3) + if True: + lane_mask = batch["lane_mask"] + fut_xy = batch["agent_fut"][..., :2] + fut_sc = batch["agent_fut"][..., 6:8] + fut_sc = GeoUtils.normalize_sc(fut_sc) + fut_xysc = torch.cat([fut_xy, fut_sc], -1) + + end_points = fut_xysc[:, :, -1] # Only look at final time for GT! + lane_xyh = batch["lane_xyh"] + M = lane_xyh.size(1) + lane_xysc = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + + GT_lane_mode, _ = SimpleLaneRelation.categorize_lane_relation_pts( + end_points.reshape(B * N, 1, 4), + lane_xysc.repeat_interleave(N, 0), + fut_mask.any(-1).reshape(B * N, 1), + lane_mask.repeat_interleave(N, 0), + force_select=False, + ) + # You could have two lanes that it is both on + + GT_lane_mode = GT_lane_mode.squeeze(-2).argmax(-1).reshape(B, N, M) + GT_lane_mode = torch.cat( + [GT_lane_mode, (GT_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, GT_homotopy = identify_pairwise_homotopy(fut_xy, mask=fut_mask) + GT_homotopy = GT_homotopy.type(torch.int64).reshape(B, N, N) + pred_batch["GT_lane_mode"] = GT_lane_mode + pred_batch["GT_homotopy"] = GT_homotopy + + pred_xysc = torch.cat( + [traj[..., :2], torch.sin(traj[..., 2:3]), torch.cos(traj[..., 2:3])], + -1, + ) + DS = pred_xysc.size(1) + + end_points = pred_xysc[:, :, :, -1] # Only look at final time + + pred_lane_mode, _ = SimpleLaneRelation.categorize_lane_relation_pts( + end_points.reshape(B * N * DS, 1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + fut_mask.any(-1).repeat_interleave(DS, 0).reshape(B * DS * N, 1), + lane_mask.repeat_interleave(DS * N, 0), + force_select=False, + ) + # You could have two lanes that it is both on + + pred_lane_mode = pred_lane_mode.squeeze(-2).argmax(-1).reshape(B, DS, N, M) + pred_lane_mode = torch.cat( + [pred_lane_mode, (pred_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, pred_homotopy = identify_pairwise_homotopy( + pred_xysc[..., :2].view(B * DS, N, Tf, 2), + mask=fut_mask.repeat_interleave(DS, 0), + ) + pred_homotopy = pred_homotopy.type(torch.int64).reshape(B, DS, N, N) + ML_homotopy_flag = (pred_homotopy[:, 0] == GT_homotopy).all(-1) + + ML_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + + metrics["ML_homotopy_correct_rate"] = TensorUtils.to_numpy( + (ML_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_homotopy_flag = (pred_homotopy == GT_homotopy[:, None]).any(1).all(-1) + all_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + metrics["all_homotopy_correct_rate"] = TensorUtils.to_numpy( + (all_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_lane_mode_flag = (pred_lane_mode[:, 0] == GT_lane_mode).all(-1) + all_lane_mode_flag = ( + (pred_lane_mode == GT_lane_mode[:, None]).any(1).all(-1) + ) + metrics["ML_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + metrics["all_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (all_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_scene_mode_flag = (pred_homotopy[:, 0] == GT_homotopy).all(-1) & ( + pred_lane_mode[:, 0] == GT_lane_mode + ).all(-1) + metrics["ML_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_scene_mode_flag = ( + (pred_homotopy == GT_homotopy[:, None]).all(-1) + & (pred_lane_mode == GT_lane_mode[:, None]).all(-1) + ).any(1) + metrics["all_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (all_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + + if "GT_homotopy" in pred_batch: + # train/validation mode + + agent_fut, fut_mask, pred_traj = TensorUtils.to_numpy( + (batch["agent_fut"], batch["fut_mask"], pred_batch["trajectories"]) + ) + pred_traj = pred_traj.reshape([B, -1, N, Tf, 3]) + if pred_traj.shape[-2] != agent_fut.shape[-2]: + return metrics + a2a_valid_flag = mode_valid_flag.unsqueeze(-1) * mode_valid_flag.unsqueeze( + -2 + ) + + extent = batch["extent"] + DS = traj.size(1) + pred_edges = self.bu.generate_edges( + batch["agent_type"].repeat_interleave(DS, 0), + extent.repeat_interleave(DS, 0), + TensorUtils.join_dimensions(traj[..., :2], 0, 2), + TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), + batch_first=True, + ) + + coll_loss, dis = collision_loss(pred_edges=pred_edges, return_dis=True) + dis_padded = { + k: v.nan_to_num(1.0).view(B, DS, -1, Tf) for k, v in dis.items() + } + edge_mask = { + k: torch.logical_not(v.isnan().any(-1)).view(B, DS, -1) + for k, v in dis.items() + } + for k in dis_padded: + metrics["collision_rate_ML_" + k] = TensorUtils.to_numpy( + (dis_padded[k][:, 0] < 0).sum() + / (edge_mask[k][:, 0].sum() + EPS) + / Tf + ).item() + metrics[f"collision_rate_all_{DS}_mode_{k}"] = TensorUtils.to_numpy( + (dis_padded[k] < 0).sum() / (edge_mask[k].sum() + EPS) / Tf + ).item() + metrics["collision_rate_ML"] = TensorUtils.to_numpy( + sum([(v[:, 0] < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k][:, 0].sum() for k in edge_mask]) + EPS) + ).item() + metrics[f"collision_rate_all_{DS}_mode"] = TensorUtils.to_numpy( + sum([(v < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k].sum() for k in edge_mask]) + EPS) + ).item() + + confidence = np.ones([B * N, 1]) + Nmode = pred_traj.shape[1] + agent_type = TensorUtils.to_numpy(batch["agent_type"]) + vehicle_mask = ( + (agent_type == AgentType.VEHICLE) + | (agent_type == AgentType.BICYCLE) + | (agent_type == AgentType.MOTORCYCLE) + ) + pedestrian_mask = agent_type == AgentType.PEDESTRIAN + agent_mask = fut_mask.any(-1) + dt = self.cfg.step_time + for Tsecond in [3.0, 4.0, 5.0, 8.0]: + if Tf < Tsecond / dt: + continue + Tf_bar = int(Tsecond / dt) + + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + metrics[f"ML_ADE@{Tsecond}"] = allADE + metrics[f"ML_FDE@{Tsecond}"] = allFDE + metrics[f"ML_vehicle_ADE@{Tsecond}"] = vehADE + metrics[f"ML_vehicle_FDE@{Tsecond}"] = vehFDE + metrics[f"oracle_pedestrian_ADE@{Tsecond}"] = pedADE + metrics[f"oracle_pedestrian_FDE@{Tsecond}"] = pedFDE + + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode, 1) / Nmode, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode, 1) / Nmode, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + + metrics[f"min_ADE@{Tsecond}s"] = allADE + metrics[f"min_FDE@{Tsecond}s"] = allFDE + metrics[f"min_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"min_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"min_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"min_pedestrian_FDE@{Tsecond}s"] = pedFDE + + return metrics diff --git a/diffstack/modules/predictors/trajectron_utils/environment.py b/diffstack/modules/predictors/trajectron_utils/environment.py new file mode 100644 index 0000000..1787914 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment.py @@ -0,0 +1,79 @@ +import orjson +import numpy as np +from itertools import product +from .node_type import NodeTypeEnum + + +class Environment(object): + def __init__(self, node_type_list, standardization, scenes=None, attention_radius=None, robot_type=None, dt=None): + self.scenes = scenes + self.node_type_list = node_type_list + self.attention_radius = attention_radius + self.NodeType = NodeTypeEnum(node_type_list) + self.robot_type = robot_type + self.dt = dt + + self.standardization = standardization + self.standardize_param_memo = dict() + + self._scenes_resample_prop = None + + def get_edge_types(self): + return list(product(self.NodeType, repeat=2)) + + def get_standardize_params(self, state, node_type): + memo_key = (orjson.dumps(state), node_type) + if memo_key in self.standardize_param_memo: + return self.standardize_param_memo[memo_key] + + standardize_mean_list = list() + standardize_std_list = list() + for entity, dims in state.items(): + for dim in dims: + standardize_mean_list.append(self.standardization[node_type][entity][dim]['mean']) + standardize_std_list.append(self.standardization[node_type][entity][dim]['std']) + standardize_mean = np.stack(standardize_mean_list) + standardize_std = np.stack(standardize_std_list) + + self.standardize_param_memo[memo_key] = (standardize_mean, standardize_std) + return standardize_mean, standardize_std + + def standardize(self, array, state, node_type, mean=None, std=None): + if mean is None and std is None: + mean, std = self.get_standardize_params(state, node_type) + elif mean is None and std is not None: + mean, _ = self.get_standardize_params(state, node_type) + elif mean is not None and std is None: + _, std = self.get_standardize_params(state, node_type) + return np.where(np.isnan(array), np.array(np.nan), (array - mean) / std) + + def unstandardize(self, array, state, node_type, mean=None, std=None): + if mean is None and std is None: + mean, std = self.get_standardize_params(state, node_type) + elif mean is None and std is not None: + mean, _ = self.get_standardize_params(state, node_type) + elif mean is not None and std is None: + _, std = self.get_standardize_params(state, node_type) + return array * std + mean + + @property + def scenes_resample_prop(self): + if self._scenes_resample_prop is None: + self._scenes_resample_prop = np.array([scene.resample_prob for scene in self.scenes]) + self._scenes_resample_prop = self._scenes_resample_prop / np.sum(self._scenes_resample_prop) + return self._scenes_resample_prop + + +class EnvironmentMetadata(Environment): + """The purpose of this class is to provide the exact same data that an Environment object does, but without the + huge scenes list (which makes this easy to serialize for pickling, e.g., for multiprocessing). + """ + def __init__(self, env): + super(EnvironmentMetadata, self).__init__(node_type_list=env.node_type_list, + standardization=env.standardization, + scenes=None, + attention_radius=env.attention_radius, + robot_type=env.robot_type, + dt=env.dt) + self.standardize_param_memo = env.standardize_param_memo + self._scenes_resample_prop = env._scenes_resample_prop diff --git a/diffstack/modules/predictors/trajectron_utils/environment/__init__.py b/diffstack/modules/predictors/trajectron_utils/environment/__init__.py new file mode 100644 index 0000000..c75c209 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/__init__.py @@ -0,0 +1,8 @@ +from .data_structures import RingBuffer, SingleHeaderNumpyArray, DoubleHeaderNumpyArray +from .scene import Scene +from .node import Node +from .scene_graph import TemporalSceneGraph, SceneGraph +from .environment import Environment, EnvironmentMetadata +from .node_type import NodeTypeEnum +from .data_utils import derivative_of, gradient_of +from .map import GeometricMap diff --git a/diffstack/modules/predictors/trajectron_utils/environment/data_structures.py b/diffstack/modules/predictors/trajectron_utils/environment/data_structures.py new file mode 100644 index 0000000..5ffdc6f --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/data_structures.py @@ -0,0 +1,276 @@ +import numpy as np +import pandas as pd +from collections import Sequence, OrderedDict + + +class RingBuffer(Sequence): + def __init__(self, capacity, dtype=float, allow_overwrite=True): + """ + Create a new ring buffer with the given capacity and element type. + Code copy-pasted from: https://github.com/eric-wieser/numpy_ringbuffer + + Parameters + ---------- + capacity: int + The maximum capacity of the ring buffer + dtype: data-type, optional + Desired type of buffer elements. Use a type like (float, 2) to + produce a buffer with shape (N, 2) + allow_overwrite: bool + If false, throw an IndexError when trying to append to an already + full buffer + """ + self._arr = np.full(capacity, np.nan, dtype) + self._left_index = 0 + self._right_index = 0 + self._capacity = capacity + self._allow_overwrite = allow_overwrite + + def _unwrap(self): + """ Copy the data from this buffer into unwrapped form """ + return np.concatenate(( + self._arr[self._left_index:min(self._right_index, self._capacity)], + self._arr[:max(self._right_index - self._capacity, 0)] + )) + + def _fix_indices(self): + """ + Enforce our invariant that 0 <= self._left_index < self._capacity + """ + if self._left_index >= self._capacity: + self._left_index -= self._capacity + self._right_index -= self._capacity + elif self._left_index < 0: + self._left_index += self._capacity + self._right_index += self._capacity + + @property + def is_full(self): + """ True if there is no more space in the buffer """ + return len(self) == self._capacity + + # numpy compatibility + def __array__(self): + return self._unwrap() + + @property + def dtype(self): + return self._arr.dtype + + @property + def shape(self): + return (len(self),) + self._arr.shape[1:] + + # these mirror methods from deque + @property + def maxlen(self): + return self._capacity + + def append(self, value): + if self.is_full: + if not self._allow_overwrite: + raise IndexError('append to a full RingBuffer with overwrite disabled') + elif not len(self): + return + else: + self._left_index += 1 + + self._arr[self._right_index % self._capacity] = value + self._right_index += 1 + self._fix_indices() + + def appendleft(self, value): + if self.is_full: + if not self._allow_overwrite: + raise IndexError('append to a full RingBuffer with overwrite disabled') + elif not len(self): + return + else: + self._right_index -= 1 + + self._left_index -= 1 + self._fix_indices() + self._arr[self._left_index] = value + + def pop(self): + if len(self) == 0: + raise IndexError("pop from an empty RingBuffer") + self._right_index -= 1 + self._fix_indices() + res = self._arr[self._right_index % self._capacity] + return res + + def popleft(self): + if len(self) == 0: + raise IndexError("pop from an empty RingBuffer") + res = self._arr[self._left_index] + self._left_index += 1 + self._fix_indices() + return res + + def extend(self, values): + lv = len(values) + if len(self) + lv > self._capacity: + if not self._allow_overwrite: + raise IndexError('extend a RingBuffer such that it would overflow, with overwrite disabled') + elif not len(self): + return + if lv >= self._capacity: + # wipe the entire array! - this may not be threadsafe + self._arr[...] = values[-self._capacity:] + self._right_index = self._capacity + self._left_index = 0 + return + + ri = self._right_index % self._capacity + sl1 = np.s_[ri:min(ri + lv, self._capacity)] + sl2 = np.s_[:max(ri + lv - self._capacity, 0)] + self._arr[sl1] = values[:sl1.stop - sl1.start] + self._arr[sl2] = values[sl1.stop - sl1.start:] + self._right_index += lv + + self._left_index = max(self._left_index, self._right_index - self._capacity) + self._fix_indices() + + def extendleft(self, values): + lv = len(values) + if len(self) + lv > self._capacity: + if not self._allow_overwrite: + raise IndexError('extend a RingBuffer such that it would overflow, with overwrite disabled') + elif not len(self): + return + if lv >= self._capacity: + # wipe the entire array! - this may not be threadsafe + self._arr[...] = values[:self._capacity] + self._right_index = self._capacity + self._left_index = 0 + return + + self._left_index -= lv + self._fix_indices() + li = self._left_index + sl1 = np.s_[li:min(li + lv, self._capacity)] + sl2 = np.s_[:max(li + lv - self._capacity, 0)] + self._arr[sl1] = values[:sl1.stop - sl1.start] + self._arr[sl2] = values[sl1.stop - sl1.start:] + + self._right_index = min(self._right_index, self._left_index + self._capacity) + + # implement Sequence methods + def __len__(self): + return self._right_index - self._left_index + + def __getitem__(self, item): + # handle simple (b[1]) and basic (b[np.array([1, 2, 3])]) fancy indexing specially + if not isinstance(item, tuple): + item_arr = np.asarray(item) + if issubclass(item_arr.dtype.type, np.integer): + item_arr = (item_arr + self._left_index) % self._capacity + return self._arr[item_arr] + + # for everything else, get it right at the expense of efficiency + return self._unwrap()[item] + + def __iter__(self): + # alarmingly, this is comparable in speed to using itertools.chain + return iter(self._unwrap()) + + # Everything else + def __repr__(self): + return ''.format(np.asarray(self)) + + +class DoubleHeaderNumpyArray(object): + def __init__(self, data: np.ndarray, header: list): + """ + Data Structure mirroring some functionality of double indexed pandas DataFrames. + Indexing options are: + [:, (header1, header2)] + [:, [(header1, header2), (header1, header2)]] + [:, {header1: [header21, header22]}] + + A SingleHeaderNumpyArray can is returned if an element of the first header is querried as attribut: + doubleHeaderNumpyArray.position -> SingleHeaderNumpyArray + + :param data: The numpy array. + :param header: The double header structure as list of tuples [(header11, header21), (header11, header22) ...] + """ + self.data = data + self.header = header + self.double_header_lookup = OrderedDict() + self.tree_header_lookup = OrderedDict() + for i, header_item in enumerate(header): + self.double_header_lookup[header_item] = i + if header_item[0] not in self.tree_header_lookup: + self.tree_header_lookup[header_item[0]] = dict() + self.tree_header_lookup[header_item[0]][header_item[1]] = i + + def __mul__(self, other): + return DoubleHeaderNumpyArray(self.data * other, self.header) + + def get_single_header_array(self, h1: str, rows=slice(None, None, None)): + data_integer_indices = list() + h2_list = list() + for h2 in self.tree_header_lookup[h1]: + data_integer_indices.append(self.tree_header_lookup[h1][h2]) + h2_list.append(h2) + return SingleHeaderNumpyArray(self.data[rows, data_integer_indices], h2_list) + + def __getitem__(self, item): + rows, columns = item + data_integer_indices = list() + if type(columns) is dict: + for h1, h2s in columns.items(): + for h2 in h2s: + data_integer_indices.append(self.double_header_lookup[(h1, h2)]) + return self.data[rows, data_integer_indices] + elif type(columns) is list: + for column in columns: + assert type(column) is tuple, "If Index is list it hast to be list of double header tuples." + data_integer_indices.append(self.double_header_lookup[column]) + return self.data[rows, data_integer_indices] + elif type(columns) is tuple: + return self.data[rows, self.double_header_lookup[columns]] + else: + assert type(item) is str, "Index must be str, list of tuples or dict of tree structure." + return self.get_single_header_array(item, rows=rows) + + def __getattr__(self, item): + if not item.startswith('_'): + if item in self.tree_header_lookup.keys(): + return self.get_single_header_array(item) + else: + try: + return self.data.__getattribute__(item) + except AttributeError: + return super().__getattribute__(item) + else: + return super().__getattribute__(item) + + +class SingleHeaderNumpyArray(object): + def __init__(self, data: np.ndarray, header: list): + self.data = data + self.header_lookup = OrderedDict({h: i for i, h in enumerate(header)}) + + def __getitem__(self, item): + rows, columns = item + data_integer_indices = list() + if type(columns) is list or type(columns) is tuple: + for column in columns: + data_integer_indices.append(self.header_lookup[column]) + else: + data_integer_indices = self.header_lookup[columns] + return self.data[rows, data_integer_indices] + + def __getattr__(self, item): + if not item.startswith('_'): + if item in self.header_lookup.keys(): + return self[:, item] + else: + try: + return self.data.__getattribute__(item) + except AttributeError: + return super().__getattribute__(item) + else: + return super().__getattribute__(item) diff --git a/diffstack/modules/predictors/trajectron_utils/environment/data_utils.py b/diffstack/modules/predictors/trajectron_utils/environment/data_utils.py new file mode 100644 index 0000000..67cc948 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/data_utils.py @@ -0,0 +1,45 @@ +import numpy as np + + +def make_continuous_copy(alpha): + alpha = (alpha + np.pi) % (2.0 * np.pi) - np.pi + continuous_x = np.zeros_like(alpha) + continuous_x[0] = alpha[0] + for i in range(1, len(alpha)): + if not (np.sign(alpha[i]) == np.sign(alpha[i - 1])) and np.abs(alpha[i]) > np.pi / 2: + continuous_x[i] = continuous_x[i - 1] + ( + alpha[i] - alpha[i - 1]) - np.sign( + (alpha[i] - alpha[i - 1])) * 2 * np.pi + else: + continuous_x[i] = continuous_x[i - 1] + (alpha[i] - alpha[i - 1]) + + return continuous_x + + +def derivative_of(x, dt=1, radian=False): + if radian: + x = make_continuous_copy(x) + + not_nan_mask = ~np.isnan(x) + masked_x = x[not_nan_mask] + + if masked_x.shape[-1] < 2: + return np.zeros_like(x) + + dx = np.full_like(x, np.nan) + dx[not_nan_mask] = np.ediff1d(masked_x, to_begin=(masked_x[1] - masked_x[0])) / dt + + return dx + + +def gradient_of(x, dt=1, radian=False): + if radian: + x = make_continuous_copy(x) + + if x[~np.isnan(x)].shape[-1] < 2: + return np.zeros_like(x) + + dx = np.full_like(x, np.nan) + dx[~np.isnan(x)] = np.gradient(x[~np.isnan(x)], dt) + + return dx diff --git a/diffstack/modules/predictors/trajectron_utils/environment/environment.py b/diffstack/modules/predictors/trajectron_utils/environment/environment.py new file mode 100644 index 0000000..1787914 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/environment.py @@ -0,0 +1,79 @@ +import orjson +import numpy as np +from itertools import product +from .node_type import NodeTypeEnum + + +class Environment(object): + def __init__(self, node_type_list, standardization, scenes=None, attention_radius=None, robot_type=None, dt=None): + self.scenes = scenes + self.node_type_list = node_type_list + self.attention_radius = attention_radius + self.NodeType = NodeTypeEnum(node_type_list) + self.robot_type = robot_type + self.dt = dt + + self.standardization = standardization + self.standardize_param_memo = dict() + + self._scenes_resample_prop = None + + def get_edge_types(self): + return list(product(self.NodeType, repeat=2)) + + def get_standardize_params(self, state, node_type): + memo_key = (orjson.dumps(state), node_type) + if memo_key in self.standardize_param_memo: + return self.standardize_param_memo[memo_key] + + standardize_mean_list = list() + standardize_std_list = list() + for entity, dims in state.items(): + for dim in dims: + standardize_mean_list.append(self.standardization[node_type][entity][dim]['mean']) + standardize_std_list.append(self.standardization[node_type][entity][dim]['std']) + standardize_mean = np.stack(standardize_mean_list) + standardize_std = np.stack(standardize_std_list) + + self.standardize_param_memo[memo_key] = (standardize_mean, standardize_std) + return standardize_mean, standardize_std + + def standardize(self, array, state, node_type, mean=None, std=None): + if mean is None and std is None: + mean, std = self.get_standardize_params(state, node_type) + elif mean is None and std is not None: + mean, _ = self.get_standardize_params(state, node_type) + elif mean is not None and std is None: + _, std = self.get_standardize_params(state, node_type) + return np.where(np.isnan(array), np.array(np.nan), (array - mean) / std) + + def unstandardize(self, array, state, node_type, mean=None, std=None): + if mean is None and std is None: + mean, std = self.get_standardize_params(state, node_type) + elif mean is None and std is not None: + mean, _ = self.get_standardize_params(state, node_type) + elif mean is not None and std is None: + _, std = self.get_standardize_params(state, node_type) + return array * std + mean + + @property + def scenes_resample_prop(self): + if self._scenes_resample_prop is None: + self._scenes_resample_prop = np.array([scene.resample_prob for scene in self.scenes]) + self._scenes_resample_prop = self._scenes_resample_prop / np.sum(self._scenes_resample_prop) + return self._scenes_resample_prop + + +class EnvironmentMetadata(Environment): + """The purpose of this class is to provide the exact same data that an Environment object does, but without the + huge scenes list (which makes this easy to serialize for pickling, e.g., for multiprocessing). + """ + def __init__(self, env): + super(EnvironmentMetadata, self).__init__(node_type_list=env.node_type_list, + standardization=env.standardization, + scenes=None, + attention_radius=env.attention_radius, + robot_type=env.robot_type, + dt=env.dt) + self.standardize_param_memo = env.standardize_param_memo + self._scenes_resample_prop = env._scenes_resample_prop diff --git a/diffstack/modules/predictors/trajectron_utils/environment/map.py b/diffstack/modules/predictors/trajectron_utils/environment/map.py new file mode 100644 index 0000000..4b20ddf --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/map.py @@ -0,0 +1,229 @@ +import torch +import numpy as np +from diffstack.modules.predictors.trajectron_utils.model.dataset.homography_warper import ( + get_rotation_matrix2d, + warp_affine_crop, +) + + +class Map(object): + def __init__(self, data, homography, description=None): + self.data = data + self.homography = homography + self.description = description + + def as_image(self): + raise NotImplementedError + + def get_cropped_maps(self, world_pts, patch_size, rotation=None, device="cpu"): + raise NotImplementedError + + def to_map_points(self, scene_pts): + raise NotImplementedError + + +class GeometricMap(Map): + """ + A Geometric Map is a int tensor of shape [layers, x, y]. The homography must transform a point in scene + coordinates to the respective point in map coordinates. + + :param data: Numpy array of shape [layers, x, y] + :param homography: Numpy array of shape [3, 3] + """ + + def __init__(self, data, homography, description=None): + # assert isinstance(data.dtype, np.floating), "Geometric Maps must be float values." + super(GeometricMap, self).__init__(data, homography, description=description) + + self._last_padding = None + self._last_padded_map = None + self._torch_map = None + + def torch_map(self, device): + if self._torch_map is not None: + return self._torch_map + self._torch_map = torch.tensor(self.data, dtype=torch.uint8, device=device) + return self._torch_map + + def as_image(self): + # We have to transpose x and y to rows and columns. Assumes origin is lower left for image + # Also we move the channels to the last dimension + return (np.transpose(self.data, (2, 1, 0))).astype(np.uint) + + def get_padded_map(self, padding_x, padding_y, device): + if self._last_padding == (padding_x, padding_y): + return self._last_padded_map + else: + self._last_padding = (padding_x, padding_y) + self._last_padded_map = torch.full( + ( + self.data.shape[0], + self.data.shape[1] + 2 * padding_x, + self.data.shape[2] + 2 * padding_y, + ), + False, + dtype=torch.uint8, + ) + self._last_padded_map[ + ..., padding_x:-padding_x, padding_y:-padding_y + ] = self.torch_map(device) + return self._last_padded_map + + @staticmethod + def batch_rotate(map_batched, centers, angles, out_height, out_width): + """ + As the input is a map and the warp_affine works on an image coordinate system we would have to + flip the y axis updown, negate the angles, and flip it back after transformation. + This, however, is the same as not flipping at and not negating the radian. + + :param map_batched: + :param centers: + :param angles: + :param out_height: + :param out_width: + :return: + """ + M = get_rotation_matrix2d(centers, angles, torch.ones_like(angles)) + rotated_map_batched = warp_affine_crop( + map_batched, centers, M, dsize=(out_height, out_width), padding_mode="zeros" + ) + + return rotated_map_batched + + @classmethod + def get_cropped_maps_from_scene_map_batch( + cls, maps, scene_pts, patch_size, rotation=None, device="cpu" + ): + """ + Returns rotated patches of each map around the transformed scene points. + ___________________ + | | | + | |ps[3] | + | | | + | | | + | o|__________| + | | ps[2] | + | | | + |_______|__________| + ps = patch_size + + :param maps: List of GeometricMap objects [bs] + :param scene_pts: Scene points: [bs, 2] + :param patch_size: Extracted Patch size after rotation: [-x, -y, +x, +y] + :param rotation: Rotations in degrees: [bs] + :param device: Device on which the rotated tensors should be returned. + :return: Rotated and cropped tensor patches. + """ + batch_size = scene_pts.shape[0] + lat_size = 2 * np.max((patch_size[0], patch_size[2])) + long_size = 2 * np.max((patch_size[1], patch_size[3])) + assert lat_size % 2 == 0, "Patch width must be divisible by 2" + assert long_size % 2 == 0, "Patch length must be divisible by 2" + lat_size_half = lat_size // 2 + long_size_half = long_size // 2 + + context_padding_x = int(np.ceil(np.sqrt(2) * lat_size)) + context_padding_y = int(np.ceil(np.sqrt(2) * long_size)) + + centers = torch.tensor( + [ + s_map.to_map_points(scene_pts[np.newaxis, i]) + for i, s_map in enumerate(maps) + ], + dtype=torch.long, + device=device, + ).squeeze(dim=1) + torch.tensor( + [context_padding_x, context_padding_y], device=device, dtype=torch.long + ) + + padded_map = [ + s_map.get_padded_map(context_padding_x, context_padding_y, device=device) + for s_map in maps + ] + + padded_map_batched = torch.stack( + [ + padded_map[i][ + ..., + centers[i, 0] + - context_padding_x : centers[i, 0] + + context_padding_x, + centers[i, 1] + - context_padding_y : centers[i, 1] + + context_padding_y, + ] + for i in range(centers.shape[0]) + ], + dim=0, + ) + + center_patches = torch.tensor( + [[context_padding_y, context_padding_x]], dtype=torch.int, device=device + ).repeat(batch_size, 1) + + if rotation is not None: + angles = torch.Tensor(rotation) + else: + angles = torch.zeros(batch_size) + + rotated_map_batched = cls.batch_rotate( + padded_map_batched / 255.0, + center_patches.float(), + angles, + long_size, + lat_size, + ) + + del padded_map_batched + + return rotated_map_batched[ + ..., + long_size_half - patch_size[1] : (long_size_half + patch_size[3]), + lat_size_half - patch_size[0] : (lat_size_half + patch_size[2]), + ] + + def get_cropped_maps(self, scene_pts, patch_size, rotation=None, device="cpu"): + """ + Returns rotated patches of the map around the transformed scene points. + ___________________ + | | | + | |ps[3] | + | | | + | | | + | o|__________| + | | ps[2] | + | | | + |_______|__________| + ps = patch_size + + :param scene_pts: Scene points: [bs, 2] + :param patch_size: Extracted Patch size after rotation: [-lat, -long, +lat, +long] + :param rotation: Rotations in degrees: [bs] + :param device: Device on which the rotated tensors should be returned. + :return: Rotated and cropped tensor patches. + """ + return self.get_cropped_maps_from_scene_map_batch( + [self] * scene_pts.shape[0], + scene_pts, + patch_size, + rotation=rotation, + device=device, + ) + + def to_map_points(self, scene_pts): + org_shape = None + if len(scene_pts.shape) > 2: + org_shape = scene_pts.shape + scene_pts = scene_pts.reshape((-1, 2)) + N, dims = scene_pts.shape + points_with_one = np.ones((dims + 1, N)) + points_with_one[:dims] = scene_pts.T + map_points = (self.homography @ points_with_one).T[..., :dims] + if org_shape is not None: + map_points = map_points.reshape(org_shape) + return map_points + + +class ImageMap(Map): + def __init__(self): + raise NotImplementedError diff --git a/diffstack/modules/predictors/trajectron_utils/environment/node.py b/diffstack/modules/predictors/trajectron_utils/environment/node.py new file mode 100644 index 0000000..22f202e --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/node.py @@ -0,0 +1,265 @@ +import random +import numpy as np +import pandas as pd +from diffstack.modules.predictors.trajectron_utils.environment import DoubleHeaderNumpyArray +from ncls import NCLS +from typing import Tuple + + +class Node(object): + def __init__(self, node_type, node_id, data, length=None, width=None, height=None, first_timestep=0, + is_robot=False, description="", frequency_multiplier=1, non_aug_node=None, extra_data=None): + self.type = node_type + self.id = node_id + self.length = length + self.width = width + self.height = height + self.first_timestep = first_timestep + self.non_aug_node = non_aug_node + + if data is not None: + if isinstance(data, pd.DataFrame): + self.data = DoubleHeaderNumpyArray(data.values, list(data.columns)) + elif isinstance(data, DoubleHeaderNumpyArray): + self.data = data + else: + self.data = None + + self.extra_data = extra_data + + self.is_robot = is_robot + self._last_timestep = None + self.description = description + self.frequency_multiplier = frequency_multiplier + + self.forward_in_time_on_next_override = False + + def __eq__(self, other): + return ((isinstance(other, self.__class__) + or isinstance(self, other.__class__)) + and self.id == other.id + and self.type == other.type) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.type, self.id)) + + def __repr__(self): + return '/'.join([self.type.name, self.id]) + + def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False): + """ + This function hard overwrites the data matrix. When using it you have to make sure that the columns + in the new data matrix correspond to the old structure. As well as setting first_timestep. + + :param data: New data matrix + :param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased. + :return: None + """ + if header is None: + self.data.data = data + else: + self.data = DoubleHeaderNumpyArray(data, header) + + self._last_timestep = None + if self.forward_in_time_on_next_override: + self.first_timestep += 1 + self.forward_in_time_on_next_override = forward_in_time_on_next_overwrite + + def scene_ts_to_node_ts(self, scene_ts) -> Tuple[np.ndarray, int, int]: + """ + Transforms timestamp from scene into timeframe of node data. + + :param scene_ts: Scene timesteps + :return: ts: Transformed timesteps, paddingl: Number of timesteps in scene range which are not available in + node data before data is available. paddingu: Number of timesteps in scene range which are not + available in node data after data is available. + """ + paddingl = (self.first_timestep - scene_ts[0]).clip(0) + paddingu = (scene_ts[1] - self.last_timestep).clip(0) + ts = np.array(scene_ts).clip(min=self.first_timestep, max=self.last_timestep) - self.first_timestep + return ts, paddingl, paddingu + + def history_points_at(self, ts) -> int: + """ + Number of history points in trajectory. Timestep is exclusive. + + :param ts: Scene timestep where the number of history points are queried. + :return: Number of history timesteps. + """ + return ts - self.first_timestep + + def get(self, tr_scene, state, padding=np.nan) -> np.ndarray: + """ + Returns a time range of multiple properties of the node. + + :param tr_scene: The timestep range (inklusive). + :param state: The state description for which the properties are returned. + :param padding: The value which should be used for padding if not enough information is available. + :return: Array of node property values. + """ + if tr_scene.size == 1: + tr_scene = np.array([tr_scene[0], tr_scene[0]]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + data_array = self.data[tr[0]:tr[1] + 1, state] + padded_data_array = np.full((length, data_array.shape[1]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + return padded_data_array + + def get_lane_points(self, tr_scene, padding=np.nan, num_lane_points=16) -> np.ndarray: + """ + :param tr_scene: The timestep range (inklusive). + """ + if tr_scene.size == 1: + tr_scene = np.array([tr_scene[0], tr_scene[0]]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + data_array = self.extra_data['lane_points'][tr[0]:tr[1] + 1] + padded_data_array = np.full((length, data_array.shape[1], data_array.shape[2]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + # extend to fixed num lane points + if num_lane_points is not None: + if padded_data_array.shape[1] == 0: + padded_data_array = np.full((length, num_lane_points, 3), fill_value=padding) + elif padded_data_array.shape[1] < num_lane_points: + pad = np.repeat(padded_data_array[:, -1:], num_lane_points-padded_data_array.shape[1], axis=1) + padded_data_array = np.concatenate((padded_data_array, pad), axis=1) + else: + padded_data_array = padded_data_array[:, :num_lane_points] + return padded_data_array + + @property + def timesteps(self) -> int: + """ + Number of available timesteps for node. + + :return: Number of available timesteps. + """ + return self.data.shape[0] + + @property + def last_timestep(self) -> int: + """ + Nodes last timestep in the Scene. + + :return: Nodes last timestep. + """ + if self._last_timestep is None: + self._last_timestep = self.first_timestep + self.timesteps - 1 + return self._last_timestep + + +class MultiNode(Node): + def __init__(self, node_type, node_id, nodes_list, is_robot=False): + super(MultiNode, self).__init__(node_type, node_id, data=None, is_robot=is_robot) + self.nodes_list = nodes_list + for node in self.nodes_list: + node.is_robot = is_robot + + self.first_timestep = min(node.first_timestep for node in self.nodes_list) + self._last_timestep = max(node.last_timestep for node in self.nodes_list) + + starts = np.array([node.first_timestep for node in self.nodes_list], dtype=np.int64) + ends = np.array([node.last_timestep for node in self.nodes_list], dtype=np.int64) + ids = np.arange(len(self.nodes_list), dtype=np.int64) + self.interval_tree = NCLS(starts, ends, ids) + + @staticmethod + def find_non_overlapping_nodes(nodes_list, min_timesteps=1) -> list: + """ + Greedily finds a set of non-overlapping nodes in the provided scene. + + :return: A list of non-overlapping nodes. + """ + non_overlapping_nodes = list() + nodes = sorted(nodes_list, key=lambda n: n.last_timestep) + current_time = 0 + for node in nodes: + if node.first_timestep >= current_time and node.timesteps >= min_timesteps: + # Include the node + non_overlapping_nodes.append(node) + current_time = node.last_timestep + + return non_overlapping_nodes + + def get_node_at_timesteps(self, scene_ts) -> Node: + possible_node_ranges = list(self.interval_tree.find_overlap(scene_ts[0], scene_ts[1] + 1)) + if not possible_node_ranges: + return Node(node_type=self.type, + node_id='EMPTY', + data=self.nodes_list[0].data * np.nan, + is_robot=self.is_robot) + + node_idx = random.choice(possible_node_ranges)[2] + return self.nodes_list[node_idx] + + def scene_ts_to_node_ts(self, scene_ts) -> Tuple[Node, np.ndarray, int, int]: + """ + Transforms timestamp from scene into timeframe of node data. + + :param scene_ts: Scene timesteps + :return: ts: Transformed timesteps, paddingl: Number of timesteps in scene range which are not available in + node data before data is available. paddingu: Number of timesteps in scene range which are not + available in node data after data is available. + """ + possible_node_ranges = list(self.interval_tree.find_overlap(scene_ts[0], scene_ts[1] + 1)) + if not possible_node_ranges: + return None, None, None, None + + node_idx = random.choice(possible_node_ranges)[2] + node = self.nodes_list[node_idx] + + paddingl = (node.first_timestep - scene_ts[0]).clip(0) + paddingu = (scene_ts[1] - node.last_timestep).clip(0) + ts = np.array(scene_ts).clip(min=node.first_timestep, max=node.last_timestep) - node.first_timestep + return node, ts, paddingl, paddingu + + def get(self, tr_scene, state, padding=np.nan) -> np.ndarray: + if tr_scene.size == 1: + tr_scene = np.array([tr_scene, tr_scene]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + + node, tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + if node is None: + state_length = sum([len(entity_dims) for entity_dims in state.values()]) + return np.full((length, state_length), fill_value=padding) + + data_array = node.data[tr[0]:tr[1] + 1, state] + padded_data_array = np.full((length, data_array.shape[1]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + return padded_data_array + + def get_all(self, tr_scene, state, padding=np.nan) -> np.ndarray: + # Assumption here is that the user is asking for all of the data in this MultiNode and to return it within a + # full scene-sized output array. + assert tr_scene.size == 2 and tr_scene[0] == 0 and self.last_timestep <= tr_scene[1] + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + state_length = sum([len(entity_dims) for entity_dims in state.values()]) + padded_data_array = np.full((length, state_length), fill_value=padding) + for node in self.nodes_list: + padded_data_array[node.first_timestep:node.last_timestep + 1] = node.data[:, state] + + return padded_data_array + + def history_points_at(self, ts) -> int: + """ + Number of history points in trajectory. Timestep is exclusive. + + :param ts: Scene timestep where the number of history points are queried. + :return: Number of history timesteps. + """ + node_idx = next(self.interval_tree.find_overlap(ts, ts + 1))[2] + node = self.nodes_list[node_idx] + return ts - node.first_timestep + + @property + def timesteps(self) -> int: + """ + Number of available timesteps for node. + + :return: Number of available timesteps. + """ + return self._last_timestep - self.first_timestep + 1 diff --git a/diffstack/modules/predictors/trajectron_utils/environment/node_type.py b/diffstack/modules/predictors/trajectron_utils/environment/node_type.py new file mode 100644 index 0000000..20b36da --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/node_type.py @@ -0,0 +1,36 @@ +class NodeType(object): + def __init__(self, name, value): + self.name = name + self.value = value + + def __repr__(self): + return self.name + + def __eq__(self, other): + if type(other) == str and self.name == other: + return True + else: + # Only check if class names match, so relative and absolute imports will be treated equal. + return (isinstance(other, self.__class__) or (other.__class__.__name__ == self.__class__.__name__)) and self.name == other.name + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.name) + + def __add__(self, other): + return self.name + other + + +class NodeTypeEnum(list): + def __init__(self, node_type_list): + self.node_type_list = node_type_list + node_types = [NodeType(name, node_type_list.index(name) + 1) for name in node_type_list] + super().__init__(node_types) + + def __getattr__(self, name): + if not name.startswith('_') and name in object.__getattribute__(self, "node_type_list"): + return self[object.__getattribute__(self, "node_type_list").index(name)] + else: + return object.__getattribute__(self, name) diff --git a/diffstack/modules/predictors/trajectron_utils/environment/scene.py b/diffstack/modules/predictors/trajectron_utils/environment/scene.py new file mode 100644 index 0000000..e8a78ea --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/scene.py @@ -0,0 +1,228 @@ +import copy +import numpy as np +from .scene_graph import TemporalSceneGraph, SceneGraph +from .node import MultiNode + + +class Scene(object): + def __init__(self, timesteps, map=None, dt=1, name="", frequency_multiplier=1, aug_func=None, + non_aug_scene=None, x_min=None, y_min=None, x_max=None, y_max=None, map_name=None): + self.map = map + self.timesteps = timesteps + self.dt = dt + self.name = name + + self.x_min = x_min + self.y_min = y_min + self.x_max = x_max + self.y_max = y_max + self.map_name = map_name + + self.nodes = [] + + self.robot = None + + self.temporal_scene_graph = None + + self.frequency_multiplier = frequency_multiplier + + self.description = "" + + self.aug_func = aug_func + self.non_aug_scene = non_aug_scene + + def add_robot_from_nodes(self, robot_type): + scenes = [self] + if hasattr(self, 'augmented'): + scenes += self.augmented + + for scn in scenes: + nodes_list = [node for node in scn.nodes if node.type == robot_type] + non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3) + scn.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True) + + for node in non_overlapping_nodes: + scn.nodes.remove(node) + scn.nodes.append(scn.robot) + + def get_clipped_input_dict(self, timestep, state): + input_dict = dict() + existing_nodes = self.get_nodes_clipped_at_time(timesteps=np.array([timestep]), + state=state) + tr_scene = np.array([timestep, timestep]) + for node in existing_nodes: + input_dict[node] = node.get(tr_scene, state[node.type]) + + return input_dict + + def get_scene_graph(self, + timestep, + attention_radius=None, + edge_addition_filter=None, + edge_removal_filter=None) -> SceneGraph: + """ + Returns the Scene Graph for a given timestep. If the Temporal Scene Graph was pre calculated, + the temporal scene graph is sliced. Otherwise the scene graph is calculated on the spot. + + :param timestep: Timestep for which the scene graph is returned. + :param attention_radius: Attention radius for each node type permutation. (Only online) + :param edge_addition_filter: Filter for adding edges (Only online) + :param edge_removal_filter: Filter for removing edges (Only online) + :return: Scene Graph for given timestep. + """ + if self.temporal_scene_graph is None: + timestep_range = np.array([timestep - len(edge_removal_filter), timestep]) + node_pos_dict = dict() + present_nodes = self.present_nodes(np.array([timestep])) + + for node in present_nodes[timestep]: + node_pos_dict[node] = np.squeeze(node.get(timestep_range, {'position': ['x', 'y']})) + tsg = TemporalSceneGraph.create_from_temp_scene_dict(node_pos_dict, + attention_radius, + duration=(len(edge_removal_filter) + 1), + edge_addition_filter=edge_addition_filter, + edge_removal_filter=edge_removal_filter + ) + + return tsg.to_scene_graph(t=len(edge_removal_filter), + t_hist=len(edge_removal_filter), + t_fut=len(edge_addition_filter)) + else: + return self.temporal_scene_graph.to_scene_graph(timestep, + len(edge_removal_filter), + len(edge_addition_filter)) + + def calculate_scene_graph(self, + attention_radius, + edge_addition_filter=None, + edge_removal_filter=None) -> None: + """ + Calculate the Temporal Scene Graph for the entire Scene. + + :param attention_radius: Attention radius for each node type permutation. + :param edge_addition_filter: Filter for adding edges. + :param edge_removal_filter: Filter for removing edges. + :return: None + """ + timestep_range = np.array([0, self.timesteps-1]) + node_pos_dict = dict() + + for node in self.nodes: + if type(node) is MultiNode: + node_pos_dict[node] = np.squeeze(node.get_all(timestep_range, {'position': ['x', 'y']})) + else: + node_pos_dict[node] = np.squeeze(node.get(timestep_range, {'position': ['x', 'y']})) + + self.temporal_scene_graph = TemporalSceneGraph.create_from_temp_scene_dict(node_pos_dict, + attention_radius, + duration=self.timesteps, + edge_addition_filter=edge_addition_filter, + edge_removal_filter=edge_removal_filter) + + def duration(self): + """ + Calculates the duration of the scene. + + :return: Duration of the scene in s. + """ + return self.timesteps * self.dt + + def present_nodes(self, + timesteps, + type=None, + min_history_timesteps=0, + min_future_timesteps=0, + return_robot=True) -> dict: + """ + Finds all present nodes in the scene at a given timestemp + + :param timesteps: Timestep(s) for which all present nodes should be returned + :param type: Node type which should be returned. If None all node types are returned. + :param min_history_timesteps: Minimum history timesteps of a node to be returned. + :param min_future_timesteps: Minimum future timesteps of a node to be returned. + :param return_robot: Return a node if it is the robot. + :return: Dictionary with timesteps as keys and list of nodes as value. + """ + + present_nodes = {} + + for node in self.nodes: + if node.is_robot and not return_robot: + continue + if type is None or node.type == type: + lower_bound = timesteps - min_history_timesteps + upper_bound = timesteps + min_future_timesteps + mask = (node.first_timestep <= lower_bound) & (upper_bound <= node.last_timestep) + if mask.any(): + timestep_indices_present = np.nonzero(mask)[0] + for timestep_index_present in timestep_indices_present: + if timesteps[timestep_index_present] in present_nodes.keys(): + present_nodes[timesteps[timestep_index_present]].append(node) + else: + present_nodes[timesteps[timestep_index_present]] = [node] + + return present_nodes + + def get_nodes_clipped_at_time(self, timesteps, state): + clipped_nodes = list() + + existing_nodes = self.present_nodes(timesteps) + all_nodes = set().union(*existing_nodes.values()) + if not all_nodes: + return clipped_nodes + + tr_scene = np.array([timesteps.min(), timesteps.max()]) + data_header_memo = dict() + for node in all_nodes: + if isinstance(node, MultiNode): + copied_node = copy.deepcopy(node.get_node_at_timesteps(tr_scene)) + copied_node.id = self.robot.id + else: + copied_node = copy.deepcopy(node) + + clipped_value = node.get(tr_scene, state[node.type]) + + if node.type not in data_header_memo: + data_header = list() + for quantity, values in state[node.type].items(): + for value in values: + data_header.append((quantity, value)) + + data_header_memo[node.type] = data_header + + copied_node.overwrite_data(clipped_value, data_header_memo[node.type]) + copied_node.first_timestep = tr_scene[0] + + clipped_nodes.append(copied_node) + + return clipped_nodes + + def sample_timesteps(self, batch_size, min_future_timesteps=0) -> np.ndarray: + """ + Sample a batch size of possible timesteps for the scene. + + :param batch_size: Number of timesteps to sample. + :param min_future_timesteps: Minimum future timesteps in the scene for a timestep to be returned. + :return: Numpy Array of sampled timesteps. + """ + if batch_size > self.timesteps: + batch_size = self.timesteps + return np.random.choice(np.arange(0, self.timesteps-min_future_timesteps), size=batch_size, replace=False) + + def augment(self): + if self.aug_func is not None: + scene_aug = np.random.choice(self.augmented) + scene_aug.temporal_scene_graph = self.temporal_scene_graph + return scene_aug + else: + return self + + def get_node_by_id(self, id): + for node in self.nodes: + if node.id == id: + return node + + def __repr__(self): + return f"Scene: Duration: {self.duration()}s," \ + f" Nodes: {len(self.nodes)}," \ + f" Map: {'Yes' if self.map is not None else 'No'}." diff --git a/diffstack/modules/predictors/trajectron_utils/environment/scene_graph.py b/diffstack/modules/predictors/trajectron_utils/environment/scene_graph.py new file mode 100644 index 0000000..1113bd4 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/environment/scene_graph.py @@ -0,0 +1,493 @@ +import numpy as np +from scipy.spatial.distance import pdist, squareform +import scipy.signal as ss +from collections import defaultdict +import warnings +from .node import Node + + +class Edge(object): + def __init__(self, curr_node, other_node): + self.id = self.get_edge_id(curr_node, other_node) + self.type = self.get_edge_type(curr_node, other_node) + self.curr_node = curr_node + self.other_node = other_node + + @staticmethod + def get_edge_id(n1, n2): + raise NotImplementedError("Use one of the Edge subclasses!") + + @staticmethod + def get_str_from_types(nt1, nt2): + raise NotImplementedError("Use one of the Edge subclasses!") + + @staticmethod + def get_edge_type(n1, n2): + raise NotImplementedError("Use one of the Edge subclasses!") + + def __eq__(self, other): + return (isinstance(other, self.__class__) + and self.id == other.id) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.id) + + def __repr__(self): + return self.id + + +class UndirectedEdge(Edge): + def __init__(self, curr_node, other_node): + super(UndirectedEdge, self).__init__(curr_node, other_node) + + @staticmethod + def get_edge_id(n1, n2): + return '-'.join(sorted([str(n1), str(n2)])) + + @staticmethod + def get_str_from_types(nt1, nt2): + return '-'.join(sorted([nt1.name, nt2.name])) + + @staticmethod + def get_edge_type(n1, n2): + return '-'.join(sorted([n1.type.name, n2.type.name])) + + +class DirectedEdge(Edge): + def __init__(self, curr_node, other_node): + super(DirectedEdge, self).__init__(curr_node, other_node) + + @staticmethod + def get_edge_id(n1, n2): + return '->'.join([str(n1), str(n2)]) + + @staticmethod + def get_str_from_types(nt1, nt2): + return '->'.join([nt1.name, nt2.name]) + + @staticmethod + def get_edge_type(n1, n2): + return '->'.join([n1.type.name, n2.type.name]) + + +class TemporalSceneGraph(object): + def __init__(self, + edge_radius, + nodes=None, + adj_cube=np.zeros((1, 0, 0)), + weight_cube=np.zeros((1, 0, 0)), + node_type_mat=np.zeros((0, 0)), + edge_scaling=None): + self.edge_radius = edge_radius + self.nodes = nodes + if nodes is None: + self.nodes = np.array([]) + self.adj_cube = adj_cube + self.weight_cube = weight_cube + self.node_type_mat = node_type_mat + self.adj_mat = np.max(self.adj_cube, axis=0).clip(max=1.0) + self.edge_scaling = edge_scaling + self.node_index_lookup = None + self.calculate_node_index_lookup() + + def calculate_node_index_lookup(self): + node_index_lookup = dict() + for i, node in enumerate(self.nodes): + node_index_lookup[node] = i + + self.node_index_lookup = node_index_lookup + + def get_num_edges(self, t=0): + return np.sum(self.adj_cube[t]) // 2 + + def get_index(self, node): + return self.node_index_lookup[node] + + @classmethod + def create_from_temp_scene_dict(cls, + scene_temp_dict, + attention_radius, + duration=1, + edge_addition_filter=None, + edge_removal_filter=None, + online=False): + """ + Construct a spatiotemporal graph from node positions in a dataset. + + :param scene_temp_dict: Dict with all nodes in scene as keys and np.ndarray with positions as value + :param attention_radius: Attention radius dict. + :param duration: Temporal duration of the graph. + :param edge_addition_filter: - + :param edge_removal_filter: - + :return: TemporalSceneGraph + """ + + nodes = scene_temp_dict.keys() + N = len(nodes) + total_timesteps = duration + + if N == 0: + return TemporalSceneGraph(attention_radius) + + position_cube = np.full((total_timesteps, N, 2), np.nan) + + adj_cube = np.zeros((total_timesteps, N, N), dtype=np.int8) + dist_cube = np.zeros((total_timesteps, N, N), dtype=np.float) + + node_type_mat = np.zeros((N, N), dtype=np.int8) + node_attention_mat = np.zeros((N, N), dtype=np.float) + + for node_idx, node in enumerate(nodes): + if online: + # RingBuffers do not have a fixed constant size. Instead, they grow up to their capacity. Thus, + # we need to fill the values preceding the RingBuffer values with NaNs to make them fill the + # position_cube. + position_cube[-scene_temp_dict[node].shape[0]:, node_idx] = scene_temp_dict[node] + else: + position_cube[:, node_idx] = scene_temp_dict[node] + + node_type_mat[:, node_idx] = node.type.value + for node_idx_from, node_from in enumerate(nodes): + node_attention_mat[node_idx_from, node_idx] = attention_radius[(node_from.type, node.type)] + + np.fill_diagonal(node_type_mat, 0) + + for timestep in range(position_cube.shape[0]): + dists = squareform(pdist(position_cube[timestep], metric='euclidean')) + + # Put a 1 for all agent pairs which are closer than the edge_radius. + # Can produce a warning as dists can be nan if no data for node is available. + # This is accepted as nan <= x evaluates to False + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + adj_matrix = (dists <= node_attention_mat).astype(np.int8) * node_type_mat + + # Remove self-loops. + np.fill_diagonal(adj_matrix, 0) + + adj_cube[timestep] = adj_matrix + dist_cube[timestep] = dists + + dist_cube[np.isnan(dist_cube)] = 0. + weight_cube = np.divide(1., + dist_cube, + out=np.zeros_like(dist_cube), + where=(dist_cube > 0.)) + edge_scaling = None + if edge_addition_filter is not None and edge_removal_filter is not None: + edge_scaling = cls.calculate_edge_scaling(adj_cube, edge_addition_filter, edge_removal_filter) + tsg = cls(attention_radius, + np.array(list(nodes)), + adj_cube, weight_cube, + node_type_mat, + edge_scaling=edge_scaling) + return tsg + + @staticmethod + def calculate_edge_scaling(adj_cube, edge_addition_filter, edge_removal_filter): + shifted_right = np.pad(adj_cube, ((len(edge_addition_filter) - 1, 0), (0, 0), (0, 0)), 'constant', constant_values=0) + + new_edges = np.minimum( + ss.convolve(shifted_right, np.reshape(edge_addition_filter, (-1, 1, 1)), 'full'), 1. + )[(len(edge_addition_filter) - 1):-(len(edge_addition_filter) - 1)] + + new_edges[adj_cube == 0] = 0 + + result = np.minimum( + ss.convolve(new_edges, np.reshape(edge_removal_filter, (-1, 1, 1)), 'full'), 1. + )[:-(len(edge_removal_filter) - 1)] + + return result + + def to_scene_graph(self, t, t_hist=0, t_fut=0): + """ + Creates a Scene Graph from a Temporal Scene Graph + + :param t: Time in Temporal Scene Graph for which Scene Graph is created. + :param t_hist: Number of history timesteps which are considered to form edges in Scene Graph. + :param t_fut: Number of future timesteps which are considered to form edges in Scene Graph. + :return: SceneGraph + """ + lower_t = np.clip(t-t_hist, a_min=0, a_max=None) + higher_t = np.clip(t + t_fut + 1, a_min=None, a_max=self.adj_cube.shape[0] + 1) + adj_mat = np.max(self.adj_cube[lower_t:higher_t], axis=0) + weight_mat = np.max(self.weight_cube[lower_t:higher_t], axis=0) + return SceneGraph(self.edge_radius, + self.nodes, + adj_mat, + weight_mat, + self.node_type_mat, + self.node_index_lookup, + edge_scaling=self.edge_scaling[t] if self.edge_scaling is not None else None) + + +class SceneGraph(object): + def __init__(self, + edge_radius, + nodes=None, + adj_mat=np.zeros((0, 0)), + weight_mat=np.zeros((0, 0)), + node_type_mat=np.zeros((0, 0)), + node_index_lookup=None, + edge_scaling=None): + self.edge_radius = edge_radius + self.nodes = nodes + if nodes is None: + self.nodes = np.array([]) + self.node_type_mat = node_type_mat + self.adj_mat = adj_mat + self.weight_mat = weight_mat + self.edge_scaling = edge_scaling + self.node_index_lookup = node_index_lookup + + def get_index(self, node): + return self.node_index_lookup[node] + + def get_num_edges(self): + return np.sum(self.adj_mat) // 2 + + def get_neighbors(self, node, node_type): + """ + Get all neighbors of a node. + + :param node: Node for which all neighbors are returned. + :param node_type: Specifies node types which are returned. + :return: List of all neighbors. + """ + node_index = self.get_index(node) + connection_mask = self.get_connection_mask(node_index) + mask = ((self.node_type_mat[node_index] == node_type.value) * connection_mask) + return self.nodes[mask] + + def get_edge_scaling(self, node=None): + if node is None: + return self.edge_scaling + else: + node_index = self.get_index(node) + connection_mask = self.get_connection_mask(node_index) + return self.edge_scaling[node_index, connection_mask] + + def get_edge_weight(self, node=None): + if node is None: + return self.weight_mat + else: + node_index = self.get_index(node) + connection_mask = self.get_connection_mask(node_index) + return self.weight_mat[node_index, connection_mask] + + def get_connection_mask(self, node_index): + if self.edge_scaling is None: # We do not use edge scaling + return self.adj_mat[node_index] > 0. + else: + return self.edge_scaling[node_index] > 1e-2 + + def __sub__(self, other): + new_nodes = [node for node in self.nodes if node not in other.nodes] + removed_nodes = [node for node in other.nodes if node not in self.nodes] + + our_types = set(node.type for node in self.nodes) + other_types = set(node.type for node in other.nodes) + all_node_types = our_types | other_types + + new_neighbors = defaultdict(lambda: defaultdict(set)) + for node in self.nodes: + if node in removed_nodes: + continue + + if node in other.nodes: + for node_type in all_node_types: + new_items = set(self.get_neighbors(node, node_type)) - set(other.get_neighbors(node, node_type)) + if len(new_items) > 0: + new_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = new_items + else: + for node_type in our_types: + neighbors = self.get_neighbors(node, node_type) + if len(neighbors) > 0: + new_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors) + + removed_neighbors = defaultdict(lambda: defaultdict(set)) + for node in other.nodes: + if node in removed_nodes: + continue + + if node in self.nodes: + for node_type in all_node_types: + removed_items = set(other.get_neighbors(node, node_type)) - set(self.get_neighbors(node, node_type)) + if len(removed_items) > 0: + removed_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = removed_items + else: + for node_type in other_types: + neighbors = other.get_neighbors(node, node_type) + if len(neighbors) > 0: + removed_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors) + + return new_nodes, removed_nodes, new_neighbors, removed_neighbors + + +if __name__ == '__main__': + from environment import NodeTypeEnum + import time + + # # # # # # # # # # # # # # # # # + # Testing edge mask calculation # + # # # # # # # # # # # # # # # # # + B = np.array([[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0], + [1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0]])[:, :, np.newaxis, np.newaxis] + print(B.shape) + + edge_addition_filter = [0.25, 0.5, 0.75, 1.0] + edge_removal_filter = [1.0, 0.5, 0.0] + for i in range(B.shape[0]): + A = B[i] # (time, N, N) + + print(A[:, 0, 0]) + + start = time.time() + new_edges = np.minimum(ss.convolve(A, np.reshape(edge_addition_filter, (-1, 1, 1)), 'full'), 1.)[(len(edge_addition_filter) - 1):] + old_edges = np.minimum(ss.convolve(A, np.reshape(edge_removal_filter, (-1, 1, 1)), 'full'), 1.)[:-(len(edge_removal_filter) - 1)] + res = np.minimum(new_edges + old_edges, 1.)[:, 0, 0] + end = time.time() + print(end - start) + print(res) + + start = time.time() + res = TemporalSceneGraph.calculate_edge_scaling(A, edge_addition_filter, edge_removal_filter)[:, 0, 0] + end = time.time() + print(end - start) + print(res) + + print('-'*40) + + # # # # # # # # # # # # # # # + # Testing graph subtraction # + # # # # # # # # # # # # # # # + print('\n' + '-' * 40 + '\n') + + node_type_list = ['PEDESTRIAN', + 'BICYCLE', + 'VEHICLE'] + nte = NodeTypeEnum(node_type_list) + + attention_radius = dict() + attention_radius[(nte.PEDESTRIAN, nte.PEDESTRIAN)] = 5.0 + attention_radius[(nte.PEDESTRIAN, nte.VEHICLE)] = 20.0 + attention_radius[(nte.PEDESTRIAN, nte.BICYCLE)] = 10.0 + attention_radius[(nte.VEHICLE, nte.PEDESTRIAN)] = 20.0 + attention_radius[(nte.VEHICLE, nte.VEHICLE)] = 20.0 + attention_radius[(nte.VEHICLE, nte.BICYCLE)] = 20.0 + attention_radius[(nte.BICYCLE, nte.PEDESTRIAN)] = 10.0 + attention_radius[(nte.BICYCLE, nte.VEHICLE)] = 20.0 + attention_radius[(nte.BICYCLE, nte.BICYCLE)] = 10.0 + + scene_dict1 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([0, 1])} + sg1 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict1, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + scene_dict2 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([1, 1])} + sg2 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict2, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + new_nodes, removed_nodes, new_neighbors, removed_neighbors = sg2 - sg1 + print('New Nodes:', new_nodes) + print('Removed Nodes:', removed_nodes) + print('New Neighbors:', new_neighbors) + print('Removed Neighbors:', removed_neighbors) + + # # # # # # # # # # # # # # # + print('\n' + '-' * 40 + '\n') + + scene_dict1 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([0, 1])} + sg1 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict1, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + scene_dict2 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([1, 1]), + Node(nte.PEDESTRIAN, node_id='3'): np.array([20, 1])} + sg2 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict2, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + new_nodes, removed_nodes, new_neighbors, removed_neighbors = sg2 - sg1 + print('New Nodes:', new_nodes) + print('Removed Nodes:', removed_nodes) + print('New Neighbors:', new_neighbors) + print('Removed Neighbors:', removed_neighbors) + + # # # # # # # # # # # # # # # + print('\n' + '-' * 40 + '\n') + + scene_dict1 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([0, 1])} + sg1 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict1, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + scene_dict2 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([1, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([10, 1]), + Node(nte.PEDESTRIAN, node_id='3'): np.array([20, 1])} + sg2 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict2, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + new_nodes, removed_nodes, new_neighbors, removed_neighbors = sg2 - sg1 + print('New Nodes:', new_nodes) + print('Removed Nodes:', removed_nodes) + print('New Neighbors:', new_neighbors) + print('Removed Neighbors:', removed_neighbors) + + # # # # # # # # # # # # # # # + print('\n' + '-' * 40 + '\n') + + scene_dict1 = {Node(nte.PEDESTRIAN, node_id='1'): np.array([0, 0]), + Node(nte.PEDESTRIAN, node_id='2'): np.array([0, 1])} + sg1 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict1, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + scene_dict2 = {Node(nte.PEDESTRIAN, node_id='2'): np.array([10, 1]), + Node(nte.PEDESTRIAN, node_id='3'): np.array([12, 1]), + Node(nte.PEDESTRIAN, node_id='4'): np.array([13, 1])} + sg2 = TemporalSceneGraph.create_from_temp_scene_dict( + scene_dict2, + attention_radius=attention_radius, + duration=1, + edge_addition_filter=[0.25, 0.5, 0.75, 1.0], + edge_removal_filter=[1.0, 0.0]).to_scene_graph(t=0) + + new_nodes, removed_nodes, new_neighbors, removed_neighbors = sg2 - sg1 + print('New Nodes:', new_nodes) + print('Removed Nodes:', removed_nodes) + print('New Neighbors:', new_neighbors) + print('Removed Neighbors:', removed_neighbors) diff --git a/diffstack/modules/predictors/trajectron_utils/model/__init__.py b/diffstack/modules/predictors/trajectron_utils/model/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/__init__.py @@ -0,0 +1 @@ + diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/__init__.py b/diffstack/modules/predictors/trajectron_utils/model/components/__init__.py new file mode 100644 index 0000000..116a37c --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/__init__.py @@ -0,0 +1,4 @@ +from .discrete_latent import DiscreteLatent +from .gmm2d import GMM2D +from .map_encoder import CNNMapEncoder +from .additive_attention import AdditiveAttention, TemporallyBatchedAdditiveAttention diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/additive_attention.py b/diffstack/modules/predictors/trajectron_utils/model/components/additive_attention.py new file mode 100644 index 0000000..9362324 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/additive_attention.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AdditiveAttention(nn.Module): + # Implementing the attention module of Bahdanau et al. 2015 where + # score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1)) + def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None): + super(AdditiveAttention, self).__init__() + + if internal_dim is None: + internal_dim = int((encoder_hidden_state_dim + decoder_hidden_state_dim) / 2) + + self.w1 = nn.Linear(encoder_hidden_state_dim, internal_dim, bias=False) + self.w2 = nn.Linear(decoder_hidden_state_dim, internal_dim, bias=False) + self.v = nn.Linear(internal_dim, 1, bias=False) + + def score(self, encoder_state, decoder_state): + # encoder_state is of shape (batch, enc_dim) + # decoder_state is of shape (batch, dec_dim) + # return value should be of shape (batch, 1) + return self.v(torch.tanh(self.w1(encoder_state) + self.w2(decoder_state))) + + def forward(self, encoder_states, decoder_state): + # encoder_states is of shape (batch, num_enc_states, enc_dim) + # decoder_state is of shape (batch, dec_dim) + score_vec = torch.cat([self.score(encoder_states[:, i], decoder_state) for i in range(encoder_states.shape[1])], + dim=1) + # score_vec is of shape (batch, num_enc_states) + + attention_probs = torch.unsqueeze(F.softmax(score_vec, dim=1), dim=2) + # attention_probs is of shape (batch, num_enc_states, 1) + + final_context_vec = torch.sum(attention_probs * encoder_states, dim=1) + # final_context_vec is of shape (batch, enc_dim) + + return final_context_vec, attention_probs + + +class TemporallyBatchedAdditiveAttention(AdditiveAttention): + # Implementing the attention module of Bahdanau et al. 2015 where + # score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1)) + def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None): + super(TemporallyBatchedAdditiveAttention, self).__init__(encoder_hidden_state_dim, + decoder_hidden_state_dim, + internal_dim) + + def score(self, encoder_state, decoder_state): + # encoder_state is of shape (batch, num_enc_states, max_time, enc_dim) + # decoder_state is of shape (batch, max_time, dec_dim) + # return value should be of shape (batch, num_enc_states, max_time, 1) + return self.v(torch.tanh(self.w1(encoder_state) + torch.unsqueeze(self.w2(decoder_state), dim=1))) + + def forward(self, encoder_states, decoder_state): + # encoder_states is of shape (batch, num_enc_states, max_time, enc_dim) + # decoder_state is of shape (batch, max_time, dec_dim) + score_vec = self.score(encoder_states, decoder_state) + # score_vec is of shape (batch, num_enc_states, max_time, 1) + + attention_probs = F.softmax(score_vec, dim=1) + # attention_probs is of shape (batch, num_enc_states, max_time, 1) + + final_context_vec = torch.sum(attention_probs * encoder_states, dim=1) + # final_context_vec is of shape (batch, max_time, enc_dim) + + return final_context_vec, torch.squeeze(torch.transpose(attention_probs, 1, 2), dim=3) diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/discrete_latent.py b/diffstack/modules/predictors/trajectron_utils/model/components/discrete_latent.py new file mode 100644 index 0000000..e0509e6 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/discrete_latent.py @@ -0,0 +1,109 @@ +import torch +import torch.distributions as td +import numpy as np +from diffstack.modules.predictors.trajectron_utils.model.model_utils import ModeKeys + + +class DiscreteLatent(object): + def __init__(self, hyperparams, device): + self.hyperparams = hyperparams + self.z_dim = hyperparams['N'] * hyperparams['K'] + self.N = hyperparams['N'] + self.K = hyperparams['K'] + self.kl_min = hyperparams['kl_min'] + self.device = device + self.temp = None # filled in by MultimodalGenerativeCVAE.set_annealing_params + self.z_logit_clip = None # filled in by MultimodalGenerativeCVAE.set_annealing_params + self.p_dist = None # filled in by MultimodalGenerativeCVAE.encoder + self.q_dist = None # filled in by MultimodalGenerativeCVAE.encoder + + def dist_from_h(self, h, mode): + logits_separated = torch.reshape(h, (-1, self.N, self.K)) + logits_separated_mean_zero = logits_separated - torch.mean(logits_separated, dim=-1, keepdim=True) + if self.z_logit_clip is not None and mode == ModeKeys.TRAIN: + c = self.z_logit_clip + logits = torch.clamp(logits_separated_mean_zero, min=-c, max=c) + else: + logits = logits_separated_mean_zero + + return td.OneHotCategorical(logits=logits) + + def sample_q(self, num_samples, mode): + bs = self.p_dist.probs.size()[0] + num_components = self.N * self.K + z_NK = torch.from_numpy(self.all_one_hot_combinations(self.N, self.K)).float().to(self.device).repeat(num_samples, bs) + return torch.reshape(z_NK, (num_samples * num_components, -1, self.z_dim)) + + def sample_p(self, num_samples, mode, most_likely_z=False, full_dist=True, all_z_sep=False): + num_components = 1 + if full_dist: + bs = self.p_dist.probs.size()[0] + z_NK = torch.from_numpy(self.all_one_hot_combinations(self.N, self.K)).float().to(self.device).repeat(num_samples, bs) + num_components = self.K ** self.N + k = num_samples * num_components + elif all_z_sep: + bs = self.p_dist.probs.size()[0] + z_NK = torch.from_numpy(self.all_one_hot_combinations(self.N, self.K)).float().to(self.device).repeat(1, bs) + k = self.K ** self.N + num_samples = k + elif most_likely_z: + # Sampling the most likely z from p(z|x). + eye_mat = torch.eye(self.p_dist.event_shape[-1], device=self.device) + argmax_idxs = torch.argmax(self.p_dist.probs, dim=2) + z_NK = torch.unsqueeze(eye_mat[argmax_idxs], dim=0).expand(num_samples, -1, -1, -1) + k = num_samples + else: + z_NK = self.p_dist.sample((num_samples,)) + k = num_samples + + if mode == ModeKeys.PREDICT: + return torch.reshape(z_NK, (k, -1, self.N * self.K)), num_samples, num_components + else: + return torch.reshape(z_NK, (k, -1, self.N * self.K)) + + def kl_q_p(self, log_writer=None, prefix=None, curr_iter=None): + kl_separated = td.kl_divergence(self.q_dist, self.p_dist) + if len(kl_separated.size()) < 2: + kl_separated = torch.unsqueeze(kl_separated, dim=0) + + kl_minibatch = torch.mean(kl_separated, dim=0, keepdim=True) + + if log_writer is not None: + log_writer.add_scalar(prefix + '/true_kl', torch.sum(kl_minibatch), curr_iter) + + if self.kl_min > 0: + kl_lower_bounded = torch.clamp(kl_minibatch, min=self.kl_min) + kl = torch.sum(kl_lower_bounded) + else: + kl = torch.sum(kl_minibatch) + + return kl + + def q_log_prob(self, z): + k = z.size()[0] + z_NK = torch.reshape(z, [k, -1, self.N, self.K]) + return torch.sum(self.q_dist.log_prob(z_NK), dim=2) + + def p_log_prob(self, z): + k = z.size()[0] + z_NK = torch.reshape(z, [k, -1, self.N, self.K]) + return torch.sum(self.p_dist.log_prob(z_NK), dim=2) + + def get_p_dist_probs(self): + return self.p_dist.probs + + @staticmethod + def all_one_hot_combinations(N, K): + return np.eye(K).take(np.reshape(np.indices([K] * N), [N, -1]).T, axis=0).reshape(-1, N * K) # [K**N, N*K] + + def summarize_for_tensorboard(self, log_writer, prefix, curr_iter): + log_writer.add_histogram(prefix + "/latent/p_z_x", self.p_dist.probs, curr_iter) + log_writer.add_histogram(prefix + "/latent/q_z_xy", self.q_dist.probs, curr_iter) + log_writer.add_histogram(prefix + "/latent/p_z_x_logits", self.p_dist.logits, curr_iter) + log_writer.add_histogram(prefix + "/latent/q_z_xy_logits", self.q_dist.logits, curr_iter) + if self.z_dim <= 9: + for i in range(self.N): + for j in range(self.K): + log_writer.add_histogram(prefix + "/latent/q_z_xy_logit{0}{1}".format(i, j), + self.q_dist.logits[:, i, j], + curr_iter) diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py b/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py new file mode 100644 index 0000000..7d47ba0 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py @@ -0,0 +1,181 @@ +import torch +import torch.distributions as td +import numpy as np +from diffstack.modules.predictors.trajectron_utils.model.model_utils import to_one_hot + + +class GMM2D(td.Distribution): + r""" + Gaussian Mixture Model using 2D Multivariate Gaussians each of as N components: + Cholesky decompesition and affine transformation for sampling: + + .. math:: Z \sim N(0, I) + + .. math:: S = \mu + LZ + + .. math:: S \sim N(\mu, \Sigma) \rightarrow N(\mu, LL^T) + + where :math:`L = chol(\Sigma)` and + + .. math:: \Sigma = \left[ {\begin{array}{cc} \sigma^2_x & \rho \sigma_x \sigma_y \\ \rho \sigma_x \sigma_y & \sigma^2_y \\ \end{array} } \right] + + such that + + .. math:: L = chol(\Sigma) = \left[ {\begin{array}{cc} \sigma_x & 0 \\ \rho \sigma_y & \sigma_y \sqrt{1-\rho^2} \\ \end{array} } \right] + + :param log_pis: Log Mixing Proportions :math:`log(\pi)`. [..., N] + :param mus: Mixture Components mean :math:`\mu`. [..., N * 2] + :param log_sigmas: Log Standard Deviations :math:`log(\sigma_d)`. [..., N * 2] + :param corrs: Cholesky factor of correlation :math:`\rho`. [..., N] + :param clip_lo: Clips the lower end of the standard deviation. + :param clip_hi: Clips the upper end of the standard deviation. + """ + def __init__(self, log_pis, mus, log_sigmas, corrs): + super(GMM2D, self).__init__(batch_shape=log_pis.shape[0], event_shape=log_pis.shape[1:], validate_args=False) + self.components = log_pis.shape[-1] + self.dimensions = 2 + self.device = log_pis.device + + log_pis = torch.clamp(log_pis, min=-1e5) + self.log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True) # [..., N] + self.mus = self.reshape_to_components(mus) # [..., N, 2] + self.log_sigmas = self.reshape_to_components(log_sigmas) # [..., N, 2] + self.sigmas = torch.exp(self.log_sigmas) # [..., N, 2] + self.one_minus_rho2 = 1 - corrs**2 # [..., N] + self.one_minus_rho2 = torch.clamp(self.one_minus_rho2, min=1e-5, max=1) # otherwise log can be nan + self.corrs = corrs # [..., N] + + self.L = torch.stack([torch.stack([self.sigmas[..., 0], torch.zeros_like(self.log_pis)], dim=-1), + torch.stack([self.sigmas[..., 1] * self.corrs, + self.sigmas[..., 1] * torch.sqrt(self.one_minus_rho2)], + dim=-1)], + dim=-2) + + if log_pis.numel() == 0: + self.pis_cat_dist = None + else: + self.pis_cat_dist = td.Categorical(logits=log_pis) + + def set_device(self, device): + self.device = device + self.log_pis = self.log_pis.to(device) + self.mus = self.mus.to(device) + self.log_sigmas = self.log_sigmas.to(device) + self.sigmas = self.sigmas.to(device) + self.one_minus_rho2 = self.one_minus_rho2.to(device) + self.corrs = self.corrs.to(device) + self.L = self.L.to(device) + if self.pis_cat_dist is not None: + self.pis_cat_dist = td.Categorical(logits=self.pis_cat_dist.logits.to(device)) + + @classmethod + def from_log_pis_mus_cov_mats(cls, log_pis, mus, cov_mats): + corrs_sigma12 = cov_mats[..., 0, 1] + sigma_1 = torch.clamp(cov_mats[..., 0, 0], min=1e-8) + sigma_2 = torch.clamp(cov_mats[..., 1, 1], min=1e-8) + sigmas = torch.stack([torch.sqrt(sigma_1), torch.sqrt(sigma_2)], dim=-1) + log_sigmas = torch.log(sigmas) + corrs = corrs_sigma12 / (torch.prod(sigmas, dim=-1)) + return cls(log_pis, mus, log_sigmas, corrs) + + def rsample(self, sample_shape=torch.Size()): + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. + + :param sample_shape: Shape of the samples + :return: Samples from the GMM. + """ + mvn_samples = (self.mus + + torch.squeeze( + torch.matmul(self.L, + torch.unsqueeze( + torch.randn(size=sample_shape + self.mus.shape, device=self.device), + dim=-1) + ), + dim=-1)) + if self.pis_cat_dist is None: + raise NotImplementedError # not tested + component_cat_samples = torch.zeros(sample_shape) + else: + component_cat_samples = self.pis_cat_dist.sample(sample_shape) + selector = torch.unsqueeze(to_one_hot(component_cat_samples, self.components), dim=-1) + return torch.sum(mvn_samples*selector, dim=-2) + + def log_prob(self, value): + r""" + Calculates the log probability of a value using the PDF for bivariate normal distributions: + + .. math:: + f(x | \mu, \sigma, \rho)={\frac {1}{2\pi \sigma _{x}\sigma _{y}{\sqrt {1-\rho ^{2}}}}}\exp + \left(-{\frac {1}{2(1-\rho ^{2})}}\left[{\frac {(x-\mu _{x})^{2}}{\sigma _{x}^{2}}}+ + {\frac {(y-\mu _{y})^{2}}{\sigma _{y}^{2}}}-{\frac {2\rho (x-\mu _{x})(y-\mu _{y})} + {\sigma _{x}\sigma _{y}}}\right]\right) + + :param value: The log probability density function is evaluated at those values. + :return: Log probability + """ + # x: [..., 2] + value = torch.unsqueeze(value, dim=-2) # [..., 1, 2] + dx = value - self.mus[..., :value.shape[-1]] # [..., N, 2] + + exp_nominator = ((torch.sum((dx/self.sigmas)**2, dim=-1) # first and second term of exp nominator + - 2*self.corrs*torch.prod(dx, dim=-1)/torch.prod(self.sigmas, dim=-1))) # [..., N] + + component_log_p = -(2*np.log(2*np.pi) + + torch.log(self.one_minus_rho2) + + 2*torch.sum(self.log_sigmas, dim=-1) + + exp_nominator/self.one_minus_rho2) / 2 + + return torch.logsumexp(self.log_pis + component_log_p, dim=-1) + + def get_for_node(self, n): + return self.__class__(self.log_pis[:, n:n+1], self.mus[:, n:n+1], + self.log_sigmas[:, n:n+1], self.corrs[:, n:n+1]) + + def get_for_node_at_time(self, n, t): + return self.__class__(self.log_pis[:, n:n+1, t:t+1], self.mus[:, n:n+1, t:t+1], + self.log_sigmas[:, n:n+1, t:t+1], self.corrs[:, n:n+1, t:t+1]) + + def mode(self): + """ + Calculates the mode of the GMM by calculating probabilities of a 2D mesh grid + + :param required_accuracy: Accuracy of the meshgrid + :return: Mode of the GMM + """ + if self.mus.shape[-2] > 1: + samp, bs, time, comp, _ = self.mus.shape + assert samp == 1, "For taking the mode only one sample makes sense." + mode_node_list = [] + for n in range(bs): + mode_t_list = [] + for t in range(time): + nt_gmm = self.get_for_node_at_time(n, t) + x_min = self.mus[:, n, t, :, 0].min() + x_max = self.mus[:, n, t, :, 0].max() + y_min = self.mus[:, n, t, :, 1].min() + y_max = self.mus[:, n, t, :, 1].max() + search_grid = torch.stack(torch.meshgrid([torch.arange(x_min, x_max, 0.01), + torch.arange(y_min, y_max, 0.01)]), dim=2 + ).view(-1, 2).float().to(self.device) + + ll_score = nt_gmm.log_prob(search_grid) + argmax = torch.argmax(ll_score.squeeze(), dim=0) + mode_t_list.append(search_grid[argmax]) + mode_node_list.append(torch.stack(mode_t_list, dim=0)) + return torch.stack(mode_node_list, dim=0).unsqueeze(dim=0) + return torch.squeeze(self.mus, dim=-2) + + def reshape_to_components(self, tensor): + if len(tensor.shape) == 5: + return tensor + return torch.reshape(tensor, list(tensor.shape[:-1]) + [self.components, self.dimensions]) + + def get_covariance_matrix(self): + cov = self.corrs * torch.prod(self.sigmas, dim=-1) + E = torch.stack([torch.stack([self.sigmas[..., 0]**2, cov], dim=-1), + torch.stack([cov, self.sigmas[..., 1]**2], dim=-1)], + dim=-2) + return E diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/graph_attention.py b/diffstack/modules/predictors/trajectron_utils/model/components/graph_attention.py new file mode 100644 index 0000000..fc8d89a --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/graph_attention.py @@ -0,0 +1,58 @@ +import warnings +import math +import numbers +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init, Parameter + + +class GraphMultiTypeAttention(nn.Module): + def __init__(self, in_features, hidden_features, out_features, bias=True, types=1): + super(GraphMultiTypeAttention, self).__init__() + self.types = types + self.in_features = in_features + self.out_features = out_features + self.node_self_loop_weight = Parameter(torch.Tensor(hidden_features, in_features[0])) + + self.weight_per_type = nn.ParameterList() + for i in range(types): + self.weight_per_type.append(Parameter(torch.Tensor(hidden_features, in_features[i]))) + if bias: + self.bias = Parameter(torch.Tensor(hidden_features)) + else: + self.register_parameter('bias', None) + + self.linear_to_out = nn.Linear(hidden_features, out_features, bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + for weight in self.weight_per_type: + bound = 1 / math.sqrt(weight.size(1)) + init.uniform_(weight, -bound, bound) + bound = 1 / math.sqrt(self.node_self_loop_weight.size(1)) + init.uniform_(self.node_self_loop_weight, -bound, bound) + if self.bias is not None: + init.uniform_(self.bias, -bound, bound) + + def forward(self, inputs, types, edge_weights): + weight_list = list() + for i, type in enumerate(types): + weight_list.append((edge_weights[i] / len(edge_weights)) * self.weight_per_type[type].T) + weight_list.append(self.node_self_loop_weight.T) + weight = torch.cat(weight_list, dim=0) + stacked_input = torch.cat(inputs, dim=-1) + output = stacked_input.matmul(weight) + + output = output + + if self.bias is not None: + output += self.bias + + return torch.relu(self.linear_to_out(torch.relu(output))) + + def extra_repr(self): + return 'in_features={}, hidden_features={},, out_features={}, types={}, bias={}'.format( + self.in_features, self.hidden_features, self.out_features, self.types, self.bias is not None + ) \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/map_encoder.py b/diffstack/modules/predictors/trajectron_utils/model/components/map_encoder.py new file mode 100644 index 0000000..27d6e1d --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/components/map_encoder.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CNNMapEncoder(nn.Module): + def __init__(self, map_channels, hidden_channels, output_size, masks, strides, patch_size): + super(CNNMapEncoder, self).__init__() + self.convs = nn.ModuleList() + patch_size_x = patch_size[0] + patch_size[2] + patch_size_y = patch_size[1] + patch_size[3] + input_size = (map_channels, patch_size_x, patch_size_y) + x_dummy = torch.ones(input_size).unsqueeze(0) * torch.tensor(float('nan')) + + for i, hidden_size in enumerate(hidden_channels): + self.convs.append(nn.Conv2d(map_channels if i == 0 else hidden_channels[i-1], + hidden_channels[i], masks[i], + stride=strides[i])) + x_dummy = self.convs[i](x_dummy) + + self.fc = nn.Linear(x_dummy.numel(), output_size) + + def forward(self, x, training): + for conv in self.convs: + x = F.leaky_relu(conv(x), 0.2) + x = torch.flatten(x, start_dim=1) + x = self.fc(x) + return x diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/__init__.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/__init__.py new file mode 100644 index 0000000..6c6fe59 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/__init__.py @@ -0,0 +1,2 @@ +from .dataset import EnvironmentDataset, NodeTypeDataset +from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore, get_relative_robot_traj, batchable_dict diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py new file mode 100644 index 0000000..c0da26c --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py @@ -0,0 +1,314 @@ +import os +from torch.utils import data +import time +import torch +import numpy as np + +try: + from math import prod +except ImportError: + from functools import reduce # Required in Python 3 + import operator + + def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +from .preprocessing import get_node_timestep_data +from tqdm import tqdm +from diffstack.modules.predictors.trajectron_utils.environment import ( + EnvironmentMetadata, +) +from functools import partial +from pathos.multiprocessing import ProcessPool as Pool + + +class EnvironmentDataset(object): + def __init__( + self, + env, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + **kwargs, + ): + self.env = env + self.state = state + self.pred_state = pred_state + self.hyperparams = hyperparams + self.max_ht = self.hyperparams["maximum_history_length"] + self.max_ft = kwargs["min_future_timesteps"] + self.node_type_datasets = list() + self._augment = False + for node_type in env.NodeType: + if node_type not in hyperparams["pred_state"]: + continue + self.node_type_datasets.append( + NodeTypeDataset( + env, + node_type, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + **kwargs, + ) + ) + + @property + def augment(self): + return self._augment + + @augment.setter + def augment(self, value): + self._augment = value + for node_type_dataset in self.node_type_datasets: + node_type_dataset.augment = value + + def __iter__(self): + return iter(self.node_type_datasets) + + +def parallel_process_scene( + scene, + env_metadata, + node_type, + state, + pred_state, + edge_types, + max_ht, + max_ft, + node_freq_mult, + scene_freq_mult, + hyperparams, + augment, + nusc_maps, + kwargs, +): + results = list() + indexing_info = list() + + tsteps = np.arange(0, scene.timesteps) + present_node_dict = scene.present_nodes(tsteps, type=node_type, **kwargs) + + for t, nodes in present_node_dict.items(): + for node in nodes: + if augment: + scene_aug = scene.augment() + node_aug = scene.get_node_by_id(node.id) + + scene_data = get_node_timestep_data( + env_metadata, + scene_aug, + t, + node_aug, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + ) + else: + scene_data = get_node_timestep_data( + env_metadata, + scene, + t, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + ) + + results += [(scene_data, (scene, t, node))] + + indexing_info += [ + ( + scene.frequency_multiplier if scene_freq_mult else 1, + node.frequency_multiplier if node_freq_mult else 1, + ) + ] + + return (results, indexing_info) + + +class NodeTypeDataset(data.Dataset): + def __init__( + self, + env, + node_type, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + augment=False, + **kwargs, + ): + self.env = env + self.env_metadata = EnvironmentMetadata(env) + self.state = state + self.pred_state = pred_state + self.hyperparams = hyperparams + self.max_ht = self.hyperparams["maximum_history_length"] + self.max_ft = kwargs["min_future_timesteps"] + + self.augment = augment + + self.node_type = node_type + self.edge_types = [ + edge_type for edge_type in env.get_edge_types() if edge_type[0] is node_type + ] + self.index, self.data, self.data_origin = self.index_env( + node_freq_mult, scene_freq_mult, **kwargs + ) + + def index_env(self, node_freq_mult, scene_freq_mult, **kwargs): + num_cpus = kwargs["num_workers"] + del kwargs["num_workers"] + + rank = kwargs["rank"] + del kwargs["rank"] + + if num_cpus > 0: + with Pool(num_cpus) as pool: + indexed_scenes = list( + tqdm( + pool.imap( + partial( + parallel_process_scene, + env_metadata=self.env_metadata, + node_type=self.node_type, + state=self.state, + pred_state=self.pred_state, + edge_types=self.edge_types, + max_ht=self.max_ht, + max_ft=self.max_ft, + node_freq_mult=node_freq_mult, + scene_freq_mult=scene_freq_mult, + hyperparams=self.hyperparams, + augment=self.augment, + nusc_maps=self.env.nusc_maps, + kwargs=kwargs, + ), + self.env.scenes, + ), + desc=f"Indexing {self.node_type}s ({num_cpus} CPUs)", + total=len(self.env.scenes), + disable=(rank > 0), + ) + ) + else: + indexed_scenes = [ + parallel_process_scene( + scene, + env_metadata=self.env_metadata, + node_type=self.node_type, + state=self.state, + pred_state=self.pred_state, + edge_types=self.edge_types, + max_ht=self.max_ht, + max_ft=self.max_ft, + node_freq_mult=node_freq_mult, + scene_freq_mult=scene_freq_mult, + hyperparams=self.hyperparams, + augment=self.augment, + nusc_maps=self.env.nusc_maps, + kwargs=kwargs, + ) + for scene in self.env.scenes + ] + + results = list() + indexing_info = list() + for res in indexed_scenes: + results.extend(res[0]) + indexing_info.extend(res[1]) + + index = list() + for i, counts in enumerate(indexing_info): + total = prod(counts) + + index += [i] * total + + data, data_origin = zip(*results) + + return np.asarray(index, dtype=int), list(data), list(data_origin) + + def __len__(self): + return self.index.shape[0] + + def preprocess_online(self, ind): + batch = self.data[ind] + ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_input, + neighbors_future_data, + plan_data, + ) = batch + + if "map_name" not in plan_data: + scene, _, _ = self.data_origin[ind] + plan_data["map_name"] = str(scene.map_name) + plan_data["scene_offset"] = torch.Tensor( + [scene.x_min, scene.y_min], device="cpu" + ).float() + if "most_relevant_nearby_lane_tokens" not in plan_data: + plan_data["most_relevant_nearby_lane_tokens"] = None + + return ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_input, + neighbors_future_data, + plan_data, + ) + + def __getitem__(self, i): + # https://pytorch.org/docs/master/data.html#torch.utils.data.distributed.DistributedSampler + # https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662 + + # return self.data[self.index[i]] + return self.preprocess_online(self.index[i]) + + def filter(self, filter_fn, verbose=False): + tstart = time.time() + len_start = len(self) + + if filter_fn is not None: + self.index = np.fromiter( + (i for i in self.index if filter_fn(self.data[i])), + dtype=self.index.dtype, + ) + + if verbose: + print( + "Filter: kept %d/%d (%.1f%%) of samples. Filtering took %.1fs." + % ( + len(self), + len_start, + len(self) / len_start * 100.0, + time.time() - tstart, + ) + ) diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/homography_warper.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/homography_warper.py new file mode 100644 index 0000000..885ab5f --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/homography_warper.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + + +pi = torch.tensor(3.14159265358979323846) + + +def deg2rad(tensor: torch.Tensor) -> torch.Tensor: + r"""Function that converts angles from degrees to radians. + Args: + tensor (torch.Tensor): Tensor of arbitrary shape. + Returns: + torch.Tensor: tensor with same shape as input. + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(tensor))) + + return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. + + +def angle_to_rotation_matrix(angle: torch.Tensor) -> torch.Tensor: + """ + Creates a rotation matrix out of angles in degrees + Args: + angle: (torch.Tensor): tensor of angles in degrees, any shape. + Returns: + torch.Tensor: tensor of *x2x2 rotation matrices. + Shape: + - Input: :math:`(*)` + - Output: :math:`(*, 2, 2)` + Example: + >>> input = torch.rand(1, 3) # Nx3 + >>> output = kornia.angle_to_rotation_matrix(input) # Nx3x2x2 + """ + ang_rad = deg2rad(angle) + cos_a: torch.Tensor = torch.cos(ang_rad) + sin_a: torch.Tensor = torch.sin(ang_rad) + return torch.stack([cos_a, sin_a, -sin_a, cos_a], dim=-1).view(*angle.shape, 2, 2) + + +def get_rotation_matrix2d( + center: torch.Tensor, + angle: torch.Tensor, + scale: torch.Tensor) -> torch.Tensor: + r"""Calculates an affine matrix of 2D rotation. + The function calculates the following matrix: + .. math:: + \begin{bmatrix} + \alpha & \beta & (1 - \alpha) \cdot \text{x} + - \beta \cdot \text{y} \\ + -\beta & \alpha & \beta \cdot \text{x} + + (1 - \alpha) \cdot \text{y} + \end{bmatrix} + where + .. math:: + \alpha = \text{scale} \cdot cos(\text{radian}) \\ + \beta = \text{scale} \cdot sin(\text{radian}) + The transformation maps the rotation center to itself + If this is not the target, adjust the shift. + Args: + center (Tensor): center of the rotation in the source image. + angle (Tensor): rotation radian in degrees. Positive values mean + counter-clockwise rotation (the coordinate origin is assumed to + be the top-left corner). + scale (Tensor): isotropic scale factor. + Returns: + Tensor: the affine matrix of 2D rotation. + Shape: + - Input: :math:`(B, 2)`, :math:`(B)` and :math:`(B)` + - Output: :math:`(B, 2, 3)` + Example: + >>> center = torch.zeros(1, 2) + >>> scale = torch.ones(1) + >>> radian = 45. * torch.ones(1) + >>> M = kornia.get_rotation_matrix2d(center, radian, scale) + tensor([[[ 0.7071, 0.7071, 0.0000], + [-0.7071, 0.7071, 0.0000]]]) + """ + if not torch.is_tensor(center): + raise TypeError("Input center type is not a torch.Tensor. Got {}" + .format(type(center))) + if not torch.is_tensor(angle): + raise TypeError("Input radian type is not a torch.Tensor. Got {}" + .format(type(angle))) + if not torch.is_tensor(scale): + raise TypeError("Input scale type is not a torch.Tensor. Got {}" + .format(type(scale))) + if not (len(center.shape) == 2 and center.shape[1] == 2): + raise ValueError("Input center must be a Bx2 tensor. Got {}" + .format(center.shape)) + if not len(angle.shape) == 1: + raise ValueError("Input radian must be a B tensor. Got {}" + .format(angle.shape)) + if not len(scale.shape) == 1: + raise ValueError("Input scale must be a B tensor. Got {}" + .format(scale.shape)) + if not (center.shape[0] == angle.shape[0] == scale.shape[0]): + raise ValueError("Inputs must have same batch size dimension. Got {}" + .format(center.shape, angle.shape, scale.shape)) + # convert radian and apply scale + scaled_rotation: torch.Tensor = angle_to_rotation_matrix(angle) * scale.view(-1, 1, 1) + alpha: torch.Tensor = scaled_rotation[:, 0, 0] + beta: torch.Tensor = scaled_rotation[:, 0, 1] + + # unpack the center to x, y coordinates + x: torch.Tensor = center[..., 0] + y: torch.Tensor = center[..., 1] + + # create output tensor + batch_size: int = center.shape[0] + M: torch.Tensor = torch.zeros( + batch_size, 2, 3, device=center.device, dtype=center.dtype) + M[..., 0:2, 0:2] = scaled_rotation + M[..., 0, 2] = (torch.tensor(1.) - alpha) * x - beta * y + M[..., 1, 2] = beta * x + (torch.tensor(1.) - alpha) * y + return M + +def convert_points_to_homogeneous(points: torch.Tensor) -> torch.Tensor: + r"""Function that converts points from Euclidean to homogeneous space. + Examples:: + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> output = kornia.convert_points_to_homogeneous(input) # BxNx4 + """ + if not isinstance(points, torch.Tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(points))) + if len(points.shape) < 2: + raise ValueError("Input must be at least a 2D tensor. Got {}".format( + points.shape)) + + return torch.nn.functional.pad(points, [0, 1], "constant", 1.0) + + +def convert_points_from_homogeneous( + points: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + r"""Function that converts points from homogeneous to Euclidean space. + Examples:: + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> output = kornia.convert_points_from_homogeneous(input) # BxNx2 + """ + if not isinstance(points, torch.Tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(points))) + + if len(points.shape) < 2: + raise ValueError("Input must be at least a 2D tensor. Got {}".format( + points.shape)) + + # we check for points at infinity + z_vec: torch.Tensor = points[..., -1:] + + # set the results of division by zeror/near-zero to 1.0 + # follow the convention of opencv: + # https://github.com/opencv/opencv/pull/14411/files + mask: torch.Tensor = torch.abs(z_vec) > eps + scale: torch.Tensor = torch.ones_like(z_vec).masked_scatter_( + mask, torch.tensor(1.0).to(points.device) / z_vec[mask]) + + return scale * points[..., :-1] + +def transform_points(trans_01: torch.Tensor, + points_1: torch.Tensor) -> torch.Tensor: + r"""Function that applies transformations to a set of points. + Args: + trans_01 (torch.Tensor): tensor for transformations of shape + :math:`(B, D+1, D+1)`. + points_1 (torch.Tensor): tensor of points of shape :math:`(B, N, D)`. + Returns: + torch.Tensor: tensor of N-dimensional points. + Shape: + - Output: :math:`(B, N, D)` + Examples: + >>> points_1 = torch.rand(2, 4, 3) # BxNx3 + >>> trans_01 = torch.eye(4).view(1, 4, 4) # Bx4x4 + >>> points_0 = kornia.transform_points(trans_01, points_1) # BxNx3 + """ + if not torch.is_tensor(trans_01) or not torch.is_tensor(points_1): + raise TypeError("Input type is not a torch.Tensor") + if not trans_01.device == points_1.device: + raise TypeError("Tensor must be in the same device") + if not trans_01.shape[0] == points_1.shape[0] and trans_01.shape[0] != 1: + raise ValueError("Input batch size must be the same for both tensors or 1") + if not trans_01.shape[-1] == (points_1.shape[-1] + 1): + raise ValueError("Last input dimensions must differe by one unit") + # to homogeneous + points_1_h = convert_points_to_homogeneous(points_1) # BxNxD+1 + # transform coordinates + points_0_h = torch.matmul( + trans_01.unsqueeze(1), points_1_h.unsqueeze(-1)) + points_0_h = torch.squeeze(points_0_h, dim=-1) + # to euclidean + points_0 = convert_points_from_homogeneous(points_0_h) # BxNxD + return points_0 + + +def multi_linspace(a, b, num, endpoint=True, device='cpu', dtype=torch.float): + """This function is just like np.linspace, but will create linearly + spaced vectors from a start to end vector. + Inputs: + a - Start vector. + b - End vector. + num - Number of samples to generate. Default is 50. Must be above 0. + endpoint - If True, b is the last sample. + Otherwise, it is not included. Default is True. + """ + + return a[..., None] + (b-a)[..., None]/(num-endpoint) * torch.arange(num, device=device, dtype=dtype) + + +def create_batched_meshgrid( + x_min: torch.Tensor, + y_min: torch.Tensor, + x_max: torch.Tensor, + y_max: torch.Tensor, + height: int, + width: int, + device: Optional[torch.device] = torch.device('cpu')) -> torch.Tensor: + """Generates a coordinate grid for an image. + When the flag `normalized_coordinates` is set to True, the grid is + normalized to be in the range [-1,1] to be consistent with the pytorch + function grid_sample. + http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample + Args: + height (int): the image height (rows). + width (int): the image width (cols). + normalized_coordinates (Optional[bool]): whether to normalize + coordinates in the range [-1, 1] in order to be consistent with the + PyTorch function grid_sample. + Return: + torch.Tensor: returns a grid tensor with shape :math:`(1, H, W, 2)`. + """ + # generate coordinates + xs = multi_linspace(x_min, x_max, width, device=device, dtype=torch.float) + ys = multi_linspace(y_min, y_max, height, device=device, dtype=torch.float) + + # generate grid by stacking coordinates + bs = x_min.shape[0] + batched_grid_i_list = list() + for i in range(bs): + batched_grid_i_list.append(torch.stack(torch.meshgrid([xs[i], ys[i]])).transpose(1, 2)) # 2xHxW + batched_grid: torch.Tensor = torch.stack(batched_grid_i_list, dim=0) + return batched_grid.permute(0, 2, 3, 1) # BxHxWx2 + + +def homography_warp(patch_src: torch.Tensor, + centers: torch.Tensor, + dst_homo_src: torch.Tensor, + dsize: Tuple[int, int], + mode: str = 'bilinear', + padding_mode: str = 'zeros') -> torch.Tensor: + r"""Function that warps image patchs or tensors by homographies. + See :class:`~kornia.geometry.warp.HomographyWarper` for details. + Args: + patch_src (torch.Tensor): The image or tensor to warp. Should be from + source of shape :math:`(N, C, H, W)`. + dst_homo_src (torch.Tensor): The homography or stack of homographies + from source to destination of shape + :math:`(N, 3, 3)`. + dsize (Tuple[int, int]): The height and width of the image to warp. + mode (str): interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear'. + padding_mode (str): padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros'. + Return: + torch.Tensor: Patch sampled at locations from source to destination. + Example: + >>> input = torch.rand(1, 3, 32, 32) + >>> homography = torch.eye(3).view(1, 3, 3) + >>> output = kornia.homography_warp(input, homography, (32, 32)) + """ + + out_height, out_width = dsize + image_height, image_width = patch_src.shape[-2:] + x_min = 2. * (centers[..., 0] - out_width/2) / image_width - 1. + y_min = 2. * (centers[..., 1] - out_height/2) / image_height - 1. + x_max = 2. * (centers[..., 0] + out_width/2) / image_width - 1. + y_max = 2. * (centers[..., 1] + out_height/2) / image_height - 1. + warper = HomographyWarper(x_min, y_min, x_max, y_max, out_height, out_width, mode, padding_mode) + return warper(patch_src, dst_homo_src) + + +def normal_transform_pixel(height, width): + + tr_mat = torch.Tensor([[1.0, 0.0, -1.0], + [0.0, 1.0, -1.0], + [0.0, 0.0, 1.0]]) # 1x3x3 + + tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / (width - 1.0) + tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / (height - 1.0) + + tr_mat = tr_mat.unsqueeze(0) + + return tr_mat + + +def src_norm_to_dst_norm(dst_pix_trans_src_pix: torch.Tensor, + dsize_src: Tuple[int, int], dsize_dst: Tuple[int, int]) -> torch.Tensor: + # source and destination sizes + src_h, src_w = dsize_src + dst_h, dst_w = dsize_dst + # the devices and types + device: torch.device = dst_pix_trans_src_pix.device + dtype: torch.dtype = dst_pix_trans_src_pix.dtype + # compute the transformation pixel/norm for src/dst + src_norm_trans_src_pix: torch.Tensor = normal_transform_pixel( + src_h, src_w).to(device, dtype) + src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix) + dst_norm_trans_dst_pix: torch.Tensor = normal_transform_pixel( + dst_h, dst_w).to(device, dtype) + # compute chain transformations + dst_norm_trans_src_norm: torch.Tensor = ( + dst_norm_trans_dst_pix @ (dst_pix_trans_src_pix @ src_pix_trans_src_norm) + ) + return dst_norm_trans_src_norm + + +def transform_warp_impl(src: torch.Tensor, centers: torch.Tensor, dst_pix_trans_src_pix: torch.Tensor, + dsize_src: Tuple[int, int], dsize_dst: Tuple[int, int], + grid_mode: str, padding_mode: str) -> torch.Tensor: + """Compute the transform in normalized cooridnates and perform the warping. + """ + dst_norm_trans_src_norm: torch.Tensor = src_norm_to_dst_norm( + dst_pix_trans_src_pix, dsize_src, dsize_src) + + src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm) + return homography_warp(src, centers, src_norm_trans_dst_norm, dsize_dst, grid_mode, padding_mode) + + +class HomographyWarper(nn.Module): + r"""Warps image patches or tensors by homographies. + .. math:: + X_{dst} = H_{src}^{\{dst\}} * X_{src} + Args: + height (int): The height of the image to warp. + width (int): The width of the image to warp. + mode (str): interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear'. + padding_mode (str): padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros'. + """ + + def __init__( + self, + x_min: torch.Tensor, + y_min: torch.Tensor, + x_max: torch.Tensor, + y_max: torch.Tensor, + height: int, + width: int, + mode: str = 'bilinear', + padding_mode: str = 'zeros') -> None: + super(HomographyWarper, self).__init__() + self.width: int = width + self.height: int = height + self.mode: str = mode + self.padding_mode: str = padding_mode + + # create base grid to compute the flow + self.grid: torch.Tensor = create_batched_meshgrid(x_min, y_min, x_max, y_max, height, width) + + def warp_grid(self, dst_homo_src: torch.Tensor) -> torch.Tensor: + r"""Computes the grid to warp the coordinates grid by an homography. + Args: + dst_homo_src (torch.Tensor): Homography or homographies (stacked) to + transform all points in the grid. Shape of the + homography has to be :math:`(N, 3, 3)`. + Returns: + torch.Tensor: the transformed grid of shape :math:`(N, H, W, 2)`. + """ + batch_size: int = dst_homo_src.shape[0] + device: torch.device = dst_homo_src.device + dtype: torch.dtype = dst_homo_src.dtype + # expand grid to match the input batch size + grid: torch.Tensor = self.grid + if len(dst_homo_src.shape) == 3: # local homography case + dst_homo_src = dst_homo_src.view(batch_size, 1, 3, 3) # NxHxWx3x3 + # perform the actual grid transformation, + # the grid is copied to input device and casted to the same type + flow: torch.Tensor = transform_points( + dst_homo_src, grid.to(device).to(dtype)) # NxHxWx2 + return flow.view(batch_size, self.height, self.width, 2) # NxHxWx2 + + def forward( # type: ignore + self, + patch_src: torch.Tensor, + dst_homo_src: torch.Tensor) -> torch.Tensor: + r"""Warps an image or tensor from source into reference frame. + Args: + patch_src (torch.Tensor): The image or tensor to warp. + Should be from source. + dst_homo_src (torch.Tensor): The homography or stack of homographies + from source to destination. The homography assumes normalized + coordinates [-1, 1]. + Return: + torch.Tensor: Patch sampled at locations from source to destination. + Shape: + - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)` + - Output: :math:`(N, C, H, W)` + Example: + >>> input = torch.rand(1, 3, 32, 32) + >>> homography = torch.eye(3).view(1, 3, 3) + >>> warper = kornia.HomographyWarper(32, 32) + >>> output = warper(input, homography) # NxCxHxW + """ + if not dst_homo_src.device == patch_src.device: + raise TypeError("Patch and homography must be on the same device. \ + Got patch.device: {} dst_H_src.device: {}." + .format(patch_src.device, dst_homo_src.device)) + + return F.grid_sample(patch_src, self.warp_grid(dst_homo_src), # type: ignore + mode=self.mode, padding_mode=self.padding_mode, align_corners=True) + + +def warp_affine_crop(src: torch.Tensor, centers: torch.Tensor, M: torch.Tensor, + dsize: Tuple[int, int], flags: str = 'bilinear', + padding_mode: str = 'zeros') -> torch.Tensor: + r"""Applies an affine transformation to a tensor. + + The function warp_affine transforms the source tensor using + the specified matrix: + + .. math:: + \text{dst}(x, y) = \text{src} \left( M_{11} x + M_{12} y + M_{13} , + M_{21} x + M_{22} y + M_{23} \right ) + + Args: + src (torch.Tensor): input tensor of shape :math:`(B, C, H, W)`. + M (torch.Tensor): affine transformation of shape :math:`(B, 2, 3)`. + dsize (Tuple[int, int]): size of the output image (height, width). + mode (str): interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear'. + padding_mode (str): padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros'. + + Returns: + torch.Tensor: the warped tensor. + + Shape: + - Output: :math:`(B, C, H, W)` + + .. note:: + See a working example `here `__. + """ + if not torch.is_tensor(src): + raise TypeError("Input src type is not a torch.Tensor. Got {}" + .format(type(src))) + + if not torch.is_tensor(M): + raise TypeError("Input M type is not a torch.Tensor. Got {}" + .format(type(M))) + + if not len(src.shape) == 4: + raise ValueError("Input src must be a BxCxHxW tensor. Got {}" + .format(src.shape)) + + if not (len(M.shape) == 3 or M.shape[-2:] == (2, 3)): + raise ValueError("Input M must be a Bx2x3 tensor. Got {}" + .format(src.shape)) + + # we generate a 3x3 transformation matrix from 2x3 affine + M_3x3: torch.Tensor = F.pad(M, [0, 0, 0, 1, 0, 0], + mode="constant", value=0) + M_3x3[:, 2, 2] += 1.0 + + # launches the warper + h, w = src.shape[-2:] + return transform_warp_impl(src, centers, M_3x3, (h, w), dsize, flags, padding_mode) \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py new file mode 100644 index 0000000..a110435 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py @@ -0,0 +1,757 @@ +import torch +import numpy as np +import collections.abc +from torch.utils.data._utils.collate import default_collate +import dill + +from nuscenes.map_expansion import arcline_path_utils +from trajdata.utils.arr_utils import batch_proj + +container_abcs = collections.abc + + +# Wrapper around dict to identify batchable dict data. +class batchable_dict(dict): + pass + + +class batchable_list(list): + pass + + +class batchable_nonuniform_tensor(torch.Tensor): + pass + + +def np_unstack(a, axis=0): + return np.moveaxis(a, axis, 0) + + +def restore(data): + """ + In case we dilled some structures to share between multiple process this function will restore them. + If the data input are not bytes we assume it was not dilled in the first place + + :param data: Possibly dilled data structure + :return: Un-dilled data structure + """ + if type(data) is bytes: + return dill.loads(data) + return data + + +def collate(batch): + if len(batch) == 0: + return batch + elem = batch[0] + if elem is None: + return None + elif ( + isinstance(elem, str) + or isinstance(elem, batchable_list) + or isinstance(elem, batchable_nonuniform_tensor) + ): + return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch + elif isinstance(elem, container_abcs.Sequence): + if ( + len(elem) == 4 + ): # We assume those are the maps, map points, headings and patch_size + scene_map, scene_pts, heading_angle, patch_size = zip(*batch) + if heading_angle[0] is None: + heading_angle = None + else: + heading_angle = torch.Tensor(heading_angle) + map = scene_map[0].get_cropped_maps_from_scene_map_batch( + scene_map, + scene_pts=torch.Tensor(scene_pts), + patch_size=patch_size[0], + rotation=heading_angle, + ) + return map + transposed = zip(*batch) + return [collate(samples) for samples in transposed] + elif isinstance(elem, batchable_dict): + # We dill the dictionary for the same reason as the neighbors structure (see below). + # Unlike for neighbors where we keep a list, here we collate elements recursively + data_dict = {key: collate([d[key] for d in batch]) for key in elem} + return ( + dill.dumps(data_dict) if torch.utils.data.get_worker_info() else data_dict + ) + elif isinstance(elem, container_abcs.Mapping): + # We have to dill the neighbors structures. Otherwise each tensor is put into + # shared memory separately -> slow, file pointer overhead + # we only do this in multiprocessing + neighbor_dict = {key: [d[key] for d in batch] for key in elem} + return ( + dill.dumps(neighbor_dict) + if torch.utils.data.get_worker_info() + else neighbor_dict + ) + try: + return default_collate(batch) + except RuntimeError: + # This happens when tensors are not of the same shape. + return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch + + +def get_relevant_lanes_np(nusc_map, scene_offset, position, yaw, vel, dt): + # - max deceleration + # - [constant speed] + # - max acceleration + # Normally we should do this along the current lane -- but if that's hard to + # determin reliably, an alternative is to consider both ego-centric and lane centric ways. + + # Simple alternative: use goal state for lane selection. + # This is biased e.g. when stopped at a traffic light, and we will miss on the + # lanes across the intersection when gt was stopped, and we will miss on the current + # nearby lanes when gt accelerates rapidly. + + interp_vel = max(abs(vel), 0.2) + x = position[0] + scene_offset[0] + y = position[1] + scene_offset[1] + state = np.hstack((position, yaw)) + # t = time.time() + lanes = nusc_map.get_records_in_radius(x, y, 20.0, ["lane", "lane_connector"]) + lanes = lanes["lane"] + lanes["lane_connector"] + # t_near = time.time() - t + # t = time.time() + relevant_lanes = list() + relevant_lane_tokens = list() + relevant_lane_arclines = list() + for lane in lanes: + lane_arcline = nusc_map.get_arcline_path(lane) + poses = arcline_path_utils.discretize_lane( + lane_arcline, resolution_meters=interp_vel * dt + ) + poses = np.array(poses) + poses[:, 0:2] -= scene_offset + delta_x, delta_y, dpsi = batch_proj(state, poses) + if abs(dpsi[0]) < np.pi / 4 and np.min(np.abs(delta_y)) < 4.5: + relevant_lanes.append(poses) + relevant_lane_tokens.append(str(lane)) + relevant_lane_arclines.append(lane_arcline) + # print (f"Time near: {t_near:.4f} filt: {time.time()-t:.4f}") # Time near: 0.1969 filt: 0.016 + return relevant_lanes, relevant_lane_tokens, relevant_lane_arclines + + +def get_relative_robot_traj(env, state, node_traj, robot_traj, node_type, robot_type): + # Make Robot State relative to node + _, std = env.get_standardize_params(state[robot_type], node_type=robot_type) + std[0:2] = env.attention_radius[(node_type, robot_type)] + robot_traj_st = env.standardize( + robot_traj, state[robot_type], node_type=robot_type, mean=node_traj, std=std + ) + robot_traj_st_t = torch.tensor(robot_traj_st, dtype=torch.float) + + return robot_traj_st_t + + +def pred_state_to_plan_state(pred_state): + """ + input: x, y, vx, vy, ax, ay, heading, delta_heading + output: x, y, heading, v, acc, delta_heading + """ + x, y, vx, vy, ax, ay, h, dh = np_unstack(pred_state, -1) + v = np.linalg.norm(np.stack((vx, vy), axis=-1), axis=-1, keepdims=False) + a = np.linalg.norm(np.stack((ax, ay), axis=-1), axis=-1, keepdims=False) + # a = np.divide(ax * vx + ay * vy, v, out=np.zeros_like(ax), where=(v > 1.)) + plan_state = np.stack([x, y, h, v, a, dh], axis=-1) + + # assert np.isclose(pred_state, plan_state_to_pred_state(plan_state)).all() # accelerations mismatch, the calculation of ax and ay cannot be recovered from a_norm and heading + return plan_state + + +def plan_state_to_pred_state(plan_state): + """ + input: x, y, h, v, a, delta_heading + output: x, y, vx, vy, ax, ay, heading, delta_heading + """ + x, y, h, v, a, dh = np_unstack(plan_state, -1) + # Assume vehicle can only move forwards + vx = v * np.cos(h) + vy = v * np.sin(h) + ax = a * np.cos(h) + ay = a * np.sin(h) + pred_state = np.stack([x, y, vx, vy, ax, ay, h, dh], axis=-1) + return pred_state + + +def get_node_closest_to_robot(scene, t, node_type=None, nodes=None): + """ + + :param scene: Scene + :param t: Timestep in scene + """ + get_pose = lambda n: n.get(np.array([t, t]), {"position": ["x", "y"]}, padding=0.0) + node_dist = lambda a, b: np.linalg.norm(get_pose(a) - get_pose(b)) + + robot_node = scene.robot + closest_node = None + closest_dist = None + + if nodes is None: + nodes = scene.nodes + + for node in nodes: + if node == robot_node: + continue + if node_type is not None and node.type != node_type: + continue + dist = node_dist(node, robot_node) + if closest_dist is None or dist < closest_dist: + closest_dist = dist + closest_node = node + + return closest_node + + +def get_node_timestep_data( + env, + scene, + t, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + scene_graph=None, + is_closed_loop=False, + closed_loop_ego_hist=None, +): + """ + Pre-processes the data for a single batch element: node state over time for a specific time in a specific scene + as well as the neighbour data for it. + + :param env: Environment + :param scene: Scene + :param t: Timestep in scene + :param node: Node + :param state: Specification of the node state + :param pred_state: Specification of the prediction state + :param edge_types: List of all Edge Types for which neighbours are pre-processed + :param max_ht: Maximum history timesteps + :param max_ft: Maximum future timesteps (prediction horizon) + :param hyperparams: Model hyperparameters + :param scene_graph: If scene graph was already computed for this scene and time you can pass it here + :return: Batch Element + """ + + # Node + timestep_range_x = np.array([t - max_ht, t]) + timestep_range_y = np.array([t + 1, t + max_ft]) + timestep_range_plan = np.array([t, t + max_ft]) + + plan_vehicle_state_dict = { + "position": ["x", "y"], + "heading": ["°"], + "velocity": ["norm"], + } + if hyperparams["dataset_version"] == "v2": + plan_vehicle_features_dict = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_dh": [ + "t" + str(i) for i in range(max_ft + 1) + ], # future fitted controls + "control_a": ["t" + str(i) for i in range(max_ft + 1)], + } + else: + plan_vehicle_features_dict = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_traj_dh": [ + "t" + str(i) for i in range(max_ft + 1) + ], # future fitted controls + "control_traj_a": ["t" + str(i) for i in range(max_ft + 1)], + # 'control_goal_dh': ["t"+str(i) for i in range(max_ft+1)], # future fitted controls + # 'control_goal_a': ["t"+str(i) for i in range(max_ft+1)], + } + plan_vehicle_features_dict_old = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_dh": [i for i in range(max_ft + 1)], # future fitted controls + "control_a": [i for i in range(max_ft + 1)], + } + plan_pedestrian_state_dict = { + "position": ["x", "y"], + } + + # Filter fields not in data + state = { + nk: {k: v for k, v in ndict.items() if k != "augment"} + for nk, ndict in state.items() + } + + x = node.get(timestep_range_x, state[node.type]) + y = node.get(timestep_range_y, pred_state[node.type]) + first_history_index = (max_ht - node.history_points_at(t)).clip(0) + + # Origin + x_origin = np.array(x)[-1] + + _, std = env.get_standardize_params(state[node.type], node.type) + std[0:2] = env.attention_radius[(node.type, node.type)] + rel_state = np.zeros_like(x[0]) + rel_state[0:2] = x_origin[0:2] + x_st = env.standardize(x, state[node.type], node.type, mean=rel_state, std=std) + if ( + list(pred_state[node.type].keys())[0] == "position" + ): # If we predict position we do it relative to current pos + y_st = env.standardize(y, pred_state[node.type], node.type, mean=rel_state[0:2]) + else: + y_st = env.standardize(y, pred_state[node.type], node.type) + + x_t = torch.tensor(x, dtype=torch.float) + y_t = torch.tensor(y, dtype=torch.float) + x_st_t = torch.tensor(x_st, dtype=torch.float) + y_st_t = torch.tensor(y_st, dtype=torch.float) + + scene_offset_np = np.array([scene.x_min, scene.y_min], dtype=np.float32) + + # Neighbors + neighbors_data_st = None + neighbors_edge_value = None + neighbors_future_data = None + plan_data = None + if hyperparams["edge_encoding"]: + # Scene Graph + scene_graph = ( + scene.get_scene_graph( + t, + env.attention_radius, + hyperparams["edge_addition_filter"], + hyperparams["edge_removal_filter"], + ) + if scene_graph is None + else scene_graph + ) + + neighbors_data_not_st = dict() # closed loop + logged_robot_data = None # closed loop + neighbors_data_st = dict() + neighbors_edge_value = dict() + neighbors_future_data = dict() + is_neighbor_parked = dict() + for edge_type in edge_types: + neighbors_data_not_st[edge_type] = list() + neighbors_data_st[edge_type] = list() + neighbors_future_data[edge_type] = list() + is_neighbor_parked[edge_type] = list() + robot_neighbor = -1 + + # We get all nodes which are connected to the current node for the current timestep + connected_nodes = scene_graph.get_neighbors(node, edge_type[1]) + + if hyperparams["dynamic_edges"] == "yes": + # We get the edge masks for the current node at the current timestep + edge_masks = torch.tensor( + scene_graph.get_edge_scaling(node), dtype=torch.float + ) + neighbors_edge_value[edge_type] = edge_masks + + for n_i, connected_node in enumerate(connected_nodes): + neighbor_state_np = connected_node.get( + timestep_range_x, state[connected_node.type], padding=0.0 + ) + + # Closed loop, replace ego trajectory + if is_closed_loop and connected_node.is_robot: + # assert logged_robot_data is None, "We can only replace robot trajectory once, and we already did that." + # We expect to get here twice, once for (VEHICLE, ROBOT) and once for (PEDESTRIAN, ROBOT) + logged_robot_data = torch.tensor( + neighbor_state_np, dtype=torch.float + ) + if closed_loop_ego_hist is not None: + assert ( + neighbor_state_np.shape[0] <= closed_loop_ego_hist.shape[0] + ) + assert ( + neighbor_state_np.shape[1] == closed_loop_ego_hist.shape[1] + ) + neighbor_state_np = closed_loop_ego_hist.copy() + + # Make State relative to node where neighbor and node have same state + _, std = env.get_standardize_params( + state[connected_node.type], node_type=connected_node.type + ) + std[0:2] = env.attention_radius[edge_type] + equal_dims = np.min((neighbor_state_np.shape[-1], x.shape[-1])) + rel_state = np.zeros_like(neighbor_state_np) + rel_state[:, ..., :equal_dims] = x_origin[..., :equal_dims] + neighbor_state_np_st = env.standardize( + neighbor_state_np, + state[connected_node.type], + node_type=connected_node.type, + mean=rel_state, + std=std, + ) + + neighbor_state = torch.tensor(neighbor_state_np_st, dtype=torch.float) + neighbors_data_st[edge_type].append(neighbor_state) + if is_closed_loop: + neighbors_data_not_st[edge_type].append( + torch.tensor(neighbor_state_np, dtype=torch.float) + ) + + # Add future states for all neighbors. Standardize with same origin and std. + if edge_type[1] == "VEHICLE": + assert connected_node.type == "VEHICLE" + + try: + neighbor_future_features_np = np.concatenate( + ( + # x, y, orient, vel + connected_node.get( + timestep_range_plan, + plan_vehicle_state_dict, + padding=np.nan, + ), + # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph + connected_node.get( + timestep_range_plan, + plan_vehicle_features_dict, + padding=np.nan, + ), + ), + axis=-1, + ) + except KeyError: + neighbor_future_features_np = np.concatenate( + ( + # x, y, orient, vel + connected_node.get( + timestep_range_plan, + plan_vehicle_state_dict, + padding=np.nan, + ), + # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph + connected_node.get( + timestep_range_plan, + plan_vehicle_features_dict_old, + padding=np.nan, + ), + ), + axis=-1, + ) + + if ( + is_closed_loop + and closed_loop_ego_hist is not None + and connected_node.is_robot + ): + neighbor_future_features_np[0, :6] = pred_state_to_plan_state( + closed_loop_ego_hist[-1] + ) + + # Add lane points + if hyperparams["dataset_version"] == "v2": + pass + else: + lane_ref_points = connected_node.get_lane_points( + timestep_range_plan, padding=np.nan, num_lane_points=16 + ) + neighbor_future_features_np = np.concatenate( + ( + neighbor_future_features_np, + lane_ref_points.reshape( + (lane_ref_points.shape[0], 16 * 3) + ), + ), + axis=1, + ) + + # # Assert accelearation norm is correct + # acc = connected_node.get(timestep_range_plan, {'acceleration': ['x', 'y']}) + # acc = np.linalg.norm(acc, axis=-1) + # if not np.logical_or(np.isclose(acc, neighbor_future_features_np[:, 5]), np.isnan(acc)).all(): + # print (acc, neighbor_future_features_np[:, 5]) + # pass + # dheading = connected_node.get(timestep_range_plan, {'heading': ['d°']})[:, 0] + # if not np.logical_or(np.isclose(dheading, neighbor_future_features_np[:, 4]), np.isnan(dheading)).all(): + # print(dheading, neighbor_future_features_np[:, 4]) + # pass + + # if np.isclose(neighbor_future_features_np[:, :2], 0.).any(): + # print("issue") + # pass + + # lane_dist = np.linalg.norm(neighbor_future_features_np[:, :2]-neighbor_future_features_np[:, 6:8], axis=-1) + # if np.any(np.nan_to_num(lane_dist, 0.) >= 50.): + # print (lane_dist) + # pass + + # is_robot_vect = np.ones((neighbor_future_features_np.shape[0], 1)) * float(connected_node.is_robot) + # neighbor_future_features_np = np.concatenate((neighbor_future_features_np, is_robot_vect), axis=-1) + # is_robot = torch.tensor([float(connected_node.is_robot)], dtype=torch.float) + if connected_node.is_robot: + robot_neighbor = torch.Tensor([n_i]).int() + + # Check if car is parked, i.e., it hasnt moved more more than 1m + xy_first = connected_node.data[0, {"position": ["x", "y"]}].astype( + np.float32 + ) + xy_last = connected_node.data[ + connected_node.last_timestep - connected_node.first_timestep, + {"position": ["x", "y"]}, + ].astype(np.float32) + is_neighbor_parked[edge_type].append( + np.square(xy_first - xy_last).sum(0) < 1.0**2 + ) + + elif edge_type[1] == "PEDESTRIAN": + neighbor_future_features_np = connected_node.get( + timestep_range_plan, plan_pedestrian_state_dict, padding=0.0 + ) # x, y | lane x y heading | projected x y + else: + raise ValueError("Unknown type {}".format(edge_type[1])) + neighbor_future_features = torch.tensor( + neighbor_future_features_np, dtype=torch.float + ) + neighbors_future_data[edge_type].append(neighbor_future_features) + + assert ( + hyperparams["dt"] == scene.dt + ) # we will rely on this hyperparam for delta_t, make sure its correct + + # Find closest neighbor for planning + # if 'planner' in hyperparams and hyperparams['planner'] and edge_type[1] == 'VEHICLE': + if edge_type[1] == "VEHICLE": + vehicle_future_f = neighbors_future_data[(node.type, "VEHICLE")] + is_parked = is_neighbor_parked[(node.type, "VEHICLE")] + + # Find most relevant agent + dists = [] + inds = [] + for n_i in range(len(vehicle_future_f)): + # Filter incomplete future states, controls or missing lanes + if torch.isnan(vehicle_future_f[n_i][:, : (4 + 2 + 3)]).any(): + continue + + # Filter parked cars for v7 only. + if hyperparams["dataset_version"] not in ["v7"]: + if is_parked[n_i]: + continue + + dist = torch.square( + y_t - vehicle_future_f[n_i][1:, :2] + ) # [1:] exclude current state for vehicle future + dist = dist.sum(dim=-1).amin( + dim=-1 + ) # sum over states, min over time + inds.append(n_i) + dists.append(dist) + + if dists: + plan_i = inds[ + torch.argmin(torch.stack(dists)) + ] # neighbor that gets closest to current node + else: + # No neighbors or all futures are incomplete + plan_i = -1 + + # Robot index, filter incomplete futures, controls or missing lanes + if ( + robot_neighbor >= 0 + and torch.isnan( + vehicle_future_f[robot_neighbor][:, : (4 + 2 + 3)] + ).any() + ): + robot_i = -1 + else: + robot_i = robot_neighbor + + # Pretend robot is the most relevant agent for closed loop + if is_closed_loop: + plan_i = robot_i + + # Get nearby lanes for most_relevant neighbor (used for trajectroy fan planner) + if plan_i >= 0: + nusc_map = nusc_maps[scene.map_name] + + # Relevant lanes that are near the goal state + x_goal_np = vehicle_future_f[plan_i][-1, :4].numpy() # x,y,h,v + pos_xy_goal, yaw_goal, vel_goal = np.split( + x_goal_np, (2, 3), axis=-1 + ) + ( + relevant_lanes, + relevant_lane_tokens, + relevant_lane_arclines, + ) = get_relevant_lanes_np( + nusc_map, + scene_offset_np, + pos_xy_goal, + yaw_goal, + vel_goal, + hyperparams["dt"], + ) + else: + relevant_lanes = [] + relevant_lane_tokens = [] + + # plan_data = torch.Tensor([float(plan_i), float(robot_neighbor)]) + plan_data = batchable_dict( + most_relevant_idx=torch.Tensor([int(plan_i)]).int().squeeze(0), + robot_idx=torch.Tensor([int(robot_i)]).int().squeeze(0), + most_relevant_nearby_lanes=batchable_list( + [torch.from_numpy(pts).float() for pts in relevant_lanes] + ), + most_relevant_nearby_lane_tokens=batchable_list( + relevant_lane_tokens + ), + map_name=str(scene.map_name), + scene_offset=torch.from_numpy(scene_offset_np), + ) + else: + assert hyperparams["planner"] in ["", "none"] + plan_data = None + + # Robot + robot_traj_st_t = None + timestep_range_r = np.array([t, t + max_ft]) + if hyperparams["incl_robot_node"]: + x_node = node.get(timestep_range_r, state[node.type]) + if scene.non_aug_scene is not None: + robot = scene.get_node_by_id(scene.non_aug_scene.robot.id) + else: + robot = scene.robot + robot_type = robot.type + robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan) + robot_traj_st_t = get_relative_robot_traj( + env, state, x_node, robot_traj, node.type, robot_type + ) + robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0 + + # Map + map_tuple = None + if hyperparams["use_map_encoding"]: + if node.type in hyperparams["map_encoder"]: + if node.non_aug_node is not None: + x = node.non_aug_node.get(np.array([t]), state[node.type]) + me_hyp = hyperparams["map_encoder"][node.type] + if "heading_state_index" in me_hyp: + heading_state_index = me_hyp["heading_state_index"] + # We have to rotate the map in the opposit direction of the agent to match them + if ( + type(heading_state_index) is list + ): # infer from velocity or heading vector + heading_angle = ( + -np.arctan2( + x[-1, heading_state_index[1]], x[-1, heading_state_index[0]] + ) + * 180 + / np.pi + ) + else: + heading_angle = -x[-1, heading_state_index] * 180 / np.pi + else: + heading_angle = None + + scene_map = scene.map[node.type] + map_point = x[-1, :2] + + patch_size = hyperparams["map_encoder"][node.type]["patch_size"] + map_tuple = (scene_map, map_point, heading_angle, patch_size) + + data_tuple = ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_tuple, + neighbors_future_data, + plan_data, + ) + if is_closed_loop: + return (data_tuple, (neighbors_data_not_st, logged_robot_data, robot_i)) + return data_tuple + + +def get_timesteps_data( + env, + scene, + t, + node_type, + state, + pred_state, + edge_types, + min_ht, + max_ht, + min_ft, + max_ft, + hyperparams, +): + """ + Puts together the inputs for ALL nodes in a given scene and timestep in it. + + :param env: Environment + :param scene: Scene + :param t: Timestep in scene + :param node_type: Node Type of nodes for which the data shall be pre-processed + :param state: Specification of the node state + :param pred_state: Specification of the prediction state + :param edge_types: List of all Edge Types for which neighbors are pre-processed + :param max_ht: Maximum history timesteps + :param max_ft: Maximum future timesteps (prediction horizon) + :param hyperparams: Model hyperparameters + :return: + """ + nodes_per_ts = scene.present_nodes( + t, + type=node_type, + min_history_timesteps=min_ht, + min_future_timesteps=max_ft, + return_robot=not hyperparams["incl_robot_node"], + ) + # Filter fields not in data + state = { + nk: {k: v for k, v in ndict.items() if k != "augment"} + for nk, ndict in state.items() + } + + batch = list() + nodes = list() + out_timesteps = list() + for timestep in nodes_per_ts.keys(): + scene_graph = scene.get_scene_graph( + timestep, + env.attention_radius, + hyperparams["edge_addition_filter"], + hyperparams["edge_removal_filter"], + ) + present_nodes = nodes_per_ts[timestep] + for node in present_nodes: + nodes.append(node) + out_timesteps.append(timestep) + batch.append( + get_node_timestep_data( + env, + scene, + timestep, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps=env.nusc_maps, + scene_graph=scene_graph, + ) + ) + if len(out_timesteps) == 0: + return None + return collate(batch), nodes, out_timesteps diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/__init__.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/__init__.py new file mode 100644 index 0000000..8f506d4 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/__init__.py @@ -0,0 +1,4 @@ +from .dynamic import Dynamic +from .single_integrator import SingleIntegrator +from .unicycle import Unicycle +from .linear import Linear diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/dynamic.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/dynamic.py new file mode 100644 index 0000000..6b03e13 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/dynamic.py @@ -0,0 +1,30 @@ + + +class Dynamic(object): + def __init__(self, dt, dyn_limits, device, model_registrar, xz_size, node_type): + self.dt = dt + self.device = device + self.dyn_limits = dyn_limits + self.initial_conditions = None + self.model_registrar = model_registrar + self.node_type = node_type + self.init_constants() + self.create_graph(xz_size) + + def set_initial_condition(self, init_con): + self.initial_conditions = init_con + + def init_constants(self): + pass + + def create_graph(self, xz_size): + pass + + def integrate_samples(self, s, x): + raise NotImplementedError + + def integrate_distribution(self, dist, x): + raise NotImplementedError + + def create_graph(self, xz_size): + pass diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/linear.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/linear.py new file mode 100644 index 0000000..ebceafd --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/linear.py @@ -0,0 +1,12 @@ +from diffstack.modules.predictors.trajectron_utils.model.dynamics import Dynamic + + +class Linear(Dynamic): + def init_constants(self): + pass + + def integrate_samples(self, v, x): + return v + + def integrate_distribution(self, v_dist, x): + return v_dist \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/single_integrator.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/single_integrator.py new file mode 100644 index 0000000..4d4bd51 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/single_integrator.py @@ -0,0 +1,64 @@ +import torch +from diffstack.modules.predictors.trajectron_utils.model.dynamics import Dynamic +from diffstack.utils.utils import block_diag +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D + + +class SingleIntegrator(Dynamic): + def init_constants(self): + self.F = torch.eye(4, device=self.device, dtype=torch.float32) + self.F[0:2, 2:] = torch.eye(2, device=self.device, dtype=torch.float32) * self.dt + self.F_t = self.F.transpose(-2, -1) + + def integrate_samples(self, v, x=None): + """ + Integrates deterministic samples of velocity. + + :param v: Velocity samples + :param x: Not used for SI. + :return: Position samples + """ + p_0 = self.initial_conditions['pos'].unsqueeze(1) + return torch.cumsum(v, dim=2) * self.dt + p_0 + + def integrate_distribution(self, v_dist, x=None): + r""" + Integrates the GMM velocity distribution to a distribution over position. + The Kalman Equations are used. + + .. math:: \mu_{t+1} =\textbf{F} \mu_{t} + + .. math:: \mathbf{\Sigma}_{t+1}={\textbf {F}} \mathbf{\Sigma}_{t} {\textbf {F}}^{T} + + .. math:: + \textbf{F} = \left[ + \begin{array}{cccc} + \sigma_x^2 & \rho_p \sigma_x \sigma_y & 0 & 0 \\ + \rho_p \sigma_x \sigma_y & \sigma_y^2 & 0 & 0 \\ + 0 & 0 & \sigma_{v_x}^2 & \rho_v \sigma_{v_x} \sigma_{v_y} \\ + 0 & 0 & \rho_v \sigma_{v_x} \sigma_{v_y} & \sigma_{v_y}^2 \\ + \end{array} + \right]_{t} + + :param v_dist: Joint GMM Distribution over velocity in x and y direction. + :param x: Not used for SI. + :return: Joint GMM Distribution over position in x and y direction. + """ + p_0 = self.initial_conditions['pos'].unsqueeze(1) + ph = v_dist.mus.shape[-3] + sample_batch_dim = list(v_dist.mus.shape[0:2]) + pos_dist_sigma_matrix_list = [] + + pos_mus = p_0[:, None] + torch.cumsum(v_dist.mus, dim=2) * self.dt + + vel_dist_sigma_matrix = v_dist.get_covariance_matrix() + pos_dist_sigma_matrix_t = torch.zeros(sample_batch_dim + [v_dist.components, 2, 2], device=self.device) + + for t in range(ph): + vel_sigma_matrix_t = vel_dist_sigma_matrix[:, :, t] + full_sigma_matrix_t = block_diag([pos_dist_sigma_matrix_t, vel_sigma_matrix_t]) + pos_dist_sigma_matrix_t = self.F[..., :2, :].matmul(full_sigma_matrix_t.matmul(self.F_t)[..., :2]) + pos_dist_sigma_matrix_list.append(pos_dist_sigma_matrix_t) + + pos_dist_sigma_matrix = torch.stack(pos_dist_sigma_matrix_list, dim=2) + return GMM2D.from_log_pis_mus_cov_mats(v_dist.log_pis, pos_mus, pos_dist_sigma_matrix) \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py new file mode 100644 index 0000000..c7a366a --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +from diffstack.modules.predictors.trajectron_utils.model.dynamics import Dynamic +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +from diffstack.utils.utils import block_diag + + +class Unicycle(Dynamic): + def init_constants(self): + self.F_s = torch.eye(4, device=self.device, dtype=torch.float32) + self.F_s[0:2, 2:] = ( + torch.eye(2, device=self.device, dtype=torch.float32) * self.dt + ) + self.F_s_t = self.F_s.transpose(-2, -1) + + def create_graph(self, xz_size): + model_if_absent = nn.Linear(xz_size + 1, 1) + self.p0_model = self.model_registrar.get_model( + f"{self.node_type}/unicycle_initializer", model_if_absent + ) + + def dynamic(self, x, u): + """ + :param x: + :param u: + :return: + """ + x_p = x[0] + y_p = x[1] + phi = x[2] + v = x[3] + dphi = u[0] + a = u[1] + + mask = torch.abs(dphi) <= 1e-2 + dphi = ~mask * dphi + (mask) * 1 + + phi_p_omega_dt = phi + dphi * self.dt + dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi + dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi + + d1 = torch.stack( + [ + ( + x_p + + (a / dphi) * dcos_domega + + v * dsin_domega + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt + ), + ( + y_p + - v * dcos_domega + + (a / dphi) * dsin_domega + - (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt + ), + phi + dphi * self.dt, + v + a * self.dt, + ], + dim=0, + ) + d2 = torch.stack( + [ + x_p + + v * torch.cos(phi) * self.dt + + (a / 2) * torch.cos(phi) * self.dt**2, + y_p + + v * torch.sin(phi) * self.dt + + (a / 2) * torch.sin(phi) * self.dt**2, + phi * torch.ones_like(a), + v + a * self.dt, + ], + dim=0, + ) + return torch.where(~mask, d1, d2) + + def integrate_samples(self, control_samples, x=None): + ph = control_samples.shape[-2] + p_0 = self.initial_conditions["pos"].unsqueeze(1) + v_0 = self.initial_conditions["vel"].unsqueeze(1) + + # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. + if p_0.size()[0] != x.size()[0]: + p_0 = p_0.repeat(x.size()[0], 1, 1) + v_0 = v_0.repeat(x.size()[0], 1, 1) + + phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0]) + + phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) + + u = torch.stack([control_samples[..., 0], control_samples[..., 1]], dim=0) + x = torch.stack( + [p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0 + ).squeeze(dim=-1) + + mus_list = [] + for t in range(ph): + x = self.dynamic(x, u[..., t]) + mus_list.append(torch.stack((x[0], x[1], x[2], x[3]), dim=-1)) + + pos_mus = torch.stack(mus_list, dim=2) + return pos_mus + + def compute_control_jacobian(self, sample_batch_dim, components, x, u): + F = torch.zeros( + sample_batch_dim + [components, 4, 2], + device=self.device, + dtype=torch.float32, + ) + + phi = x[2] + v = x[3] + dphi = u[0] + a = u[1] + + mask = torch.abs(dphi) <= 1e-2 + dphi = ~mask * dphi + (mask) * 1 + + phi_p_omega_dt = phi + dphi * self.dt + dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi + dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi + + F[..., 0, 0] = ( + (v / dphi) * torch.cos(phi_p_omega_dt) * self.dt + - (v / dphi) * dsin_domega + - (2 * a / dphi**2) * torch.sin(phi_p_omega_dt) * self.dt + - (2 * a / dphi**2) * dcos_domega + + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt**2 + ) + F[..., 0, 1] = (1 / dphi) * dcos_domega + (1 / dphi) * torch.sin( + phi_p_omega_dt + ) * self.dt + + F[..., 1, 0] = ( + (v / dphi) * dcos_domega + - (2 * a / dphi**2) * dsin_domega + + (2 * a / dphi**2) * torch.cos(phi_p_omega_dt) * self.dt + + (v / dphi) * torch.sin(phi_p_omega_dt) * self.dt + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt**2 + ) + F[..., 1, 1] = (1 / dphi) * dsin_domega - (1 / dphi) * torch.cos( + phi_p_omega_dt + ) * self.dt + + F[..., 2, 0] = self.dt + + F[..., 3, 1] = self.dt + + F_sm = torch.zeros( + sample_batch_dim + [components, 4, 2], + device=self.device, + dtype=torch.float32, + ) + + F_sm[..., 0, 1] = (torch.cos(phi) * self.dt**2) / 2 + + F_sm[..., 1, 1] = (torch.sin(phi) * self.dt**2) / 2 + + F_sm[..., 3, 1] = self.dt + + return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm) + + def compute_jacobian(self, sample_batch_dim, components, x, u): + one = torch.tensor(1) + F = torch.zeros( + sample_batch_dim + [components, 4, 4], + device=self.device, + dtype=torch.float32, + ) + + phi = x[2] + v = x[3] + dphi = u[0] + a = u[1] + + mask = torch.abs(dphi) <= 1e-2 + dphi = ~mask * dphi + (mask) * 1 + + phi_p_omega_dt = phi + dphi * self.dt + dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi + dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi + + F[..., 0, 0] = one + F[..., 1, 1] = one + F[..., 2, 2] = one + F[..., 3, 3] = one + + F[..., 0, 2] = ( + v * dcos_domega + - (a / dphi) * dsin_domega + + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt + ) + F[..., 0, 3] = dsin_domega + + F[..., 1, 2] = ( + v * dsin_domega + + (a / dphi) * dcos_domega + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt + ) + F[..., 1, 3] = -dcos_domega + + F_sm = torch.zeros( + sample_batch_dim + [components, 4, 4], + device=self.device, + dtype=torch.float32, + ) + + F_sm[..., 0, 0] = one + F_sm[..., 1, 1] = one + F_sm[..., 2, 2] = one + F_sm[..., 3, 3] = one + + F_sm[..., 0, 2] = ( + -v * torch.sin(phi) * self.dt - (a * torch.sin(phi) * self.dt**2) / 2 + ) + F_sm[..., 0, 3] = torch.cos(phi) * self.dt + + F_sm[..., 1, 2] = ( + v * torch.cos(phi) * self.dt + (a * torch.cos(phi) * self.dt**2) / 2 + ) + F_sm[..., 1, 3] = torch.sin(phi) * self.dt + + return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm) + + def integrate_distribution(self, control_dist_dphi_a, x): + sample_batch_dim = list(control_dist_dphi_a.mus.shape[0:2]) + ph = control_dist_dphi_a.mus.shape[-3] + p_0 = self.initial_conditions["pos"].unsqueeze(1) + v_0 = self.initial_conditions["vel"].unsqueeze(1) + + # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. + if p_0.size()[0] != x.size()[0]: + p_0 = p_0.repeat(x.size()[0], 1, 1) + v_0 = v_0.repeat(x.size()[0], 1, 1) + + phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0]) + + phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) + + dist_sigma_matrix = control_dist_dphi_a.get_covariance_matrix() + pos_dist_sigma_matrix_t = torch.zeros( + sample_batch_dim + [control_dist_dphi_a.components, 4, 4], + device=self.device, + ) + + u = torch.stack( + [control_dist_dphi_a.mus[..., 0], control_dist_dphi_a.mus[..., 1]], dim=0 + ) + x = torch.stack( + [p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0 + ) + + pos_dist_sigma_matrix_list = [] + mus_list = [] + for t in range(ph): + F_t = self.compute_jacobian( + sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t] + ) + G_t = self.compute_control_jacobian( + sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t] + ) + dist_sigma_matrix_t = dist_sigma_matrix[:, :, t] + pos_dist_sigma_matrix_t = F_t.matmul( + pos_dist_sigma_matrix_t.matmul(F_t.transpose(-2, -1)) + ) + G_t.matmul(dist_sigma_matrix_t.matmul(G_t.transpose(-2, -1))) + pos_dist_sigma_matrix_list.append(pos_dist_sigma_matrix_t[..., :2, :2]) + + x = self.dynamic(x, u[:, :, :, t]) + mus_list.append(torch.stack((x[0], x[1], x[2]), dim=-1)) + + pos_dist_sigma_matrix = torch.stack(pos_dist_sigma_matrix_list, dim=2) + pos_mus = torch.stack(mus_list, dim=2) + return GMM2D.from_log_pis_mus_cov_mats( + control_dist_dphi_a.log_pis, pos_mus, pos_dist_sigma_matrix + ) diff --git a/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py b/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py new file mode 100644 index 0000000..fcb5cf6 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py @@ -0,0 +1,1169 @@ +import warnings +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from diffstack.modules.predictors.trajectron_utils.model.components import * +from diffstack.modules.predictors.trajectron_utils.model.model_utils import * +import diffstack.modules.predictors.trajectron_utils.model.dynamics as dynamic_module +from diffstack.modules.predictors.trajectron_utils.environment.scene_graph import DirectedEdge +from typing import Tuple + + +class MultimodalGenerativeCVAE(nn.Module): + def __init__(self, + env, + node_type, + model_registrar, + hyperparams, + device, + edge_types, + log_writer=None): + super(MultimodalGenerativeCVAE, self).__init__() + + self.hyperparams = hyperparams + self.env = env + self.node_type = node_type + self.model_registrar = model_registrar + self.log_writer = log_writer + self.device = device + self.edge_types = [edge_type for edge_type in edge_types if edge_type[0] is node_type] + self.curr_iter = 0 + + self.node_modules = nn.ModuleDict() + + self.min_hl = self.hyperparams['minimum_history_length'] + self.max_hl = self.hyperparams['maximum_history_length'] + self.ph = self.hyperparams['prediction_horizon'] + self.state = self.hyperparams['state'] + self.pred_state = self.hyperparams['pred_state'][node_type] + self.node_type_state_lengths = {k: + int(np.sum([len(entity_dims) for entity_dims in self.state[k].values()])) + for k in self.state.keys()} + self.state_length = self.node_type_state_lengths[node_type] + + if self.hyperparams['incl_robot_node']: + self.robot_state_length = int( + np.sum([len(entity_dims) for entity_dims in self.state[env.robot_type].values()]) + ) + self.pred_state_length = int(np.sum([len(entity_dims) for entity_dims in self.pred_state.values()])) + + edge_types_str = [DirectedEdge.get_str_from_types(*edge_type) for edge_type in self.edge_types] + self.create_graphical_model(edge_types_str) + + dynamic_class = getattr(dynamic_module, hyperparams['dynamic'][self.node_type]['name']) + dyn_limits = hyperparams['dynamic'][self.node_type]['limits'] + self.dynamic = dynamic_class(self.env.dt, dyn_limits, device, + self.model_registrar, self.x_size, self.node_type) + + def set_curr_iter(self, curr_iter): + self.curr_iter = curr_iter + + def add_submodule(self, name, model_if_absent): + self.node_modules[name] = self.model_registrar.get_model(name, model_if_absent) + + def clear_submodules(self): + self.node_modules.clear() + + def create_node_models(self): + ############################ + # Node History Encoder # + ############################ + self.add_submodule(self.node_type + '/node_history_encoder', + model_if_absent=nn.LSTM(input_size=self.state_length, + hidden_size=self.hyperparams['enc_rnn_dim_history'], + batch_first=True)) + + ########################### + # Node Future Encoder # + ########################### + # We'll create this here, but then later check if in training mode. + # Based on that, we'll factor this into the computation graph (or not). + self.add_submodule(self.node_type + '/node_future_encoder', + model_if_absent=nn.LSTM(input_size=self.pred_state_length, + hidden_size=self.hyperparams['enc_rnn_dim_future'], + bidirectional=True, + batch_first=True)) + # These are related to how you initialize states for the node future encoder. + self.add_submodule(self.node_type + '/node_future_encoder/initial_h', + model_if_absent=nn.Linear(self.state_length, + self.hyperparams['enc_rnn_dim_future'])) + self.add_submodule(self.node_type + '/node_future_encoder/initial_c', + model_if_absent=nn.Linear(self.state_length, + self.hyperparams['enc_rnn_dim_future'])) + + ############################ + # Robot Future Encoder # + ############################ + # We'll create this here, but then later check if we're next to the robot. + # Based on that, we'll factor this into the computation graph (or not). + if self.hyperparams['incl_robot_node']: + self.add_submodule('robot_future_encoder', + model_if_absent=nn.LSTM(input_size=self.robot_state_length, + hidden_size=self.hyperparams['enc_rnn_dim_future'], + bidirectional=True, + batch_first=True)) + # These are related to how you initialize states for the robot future encoder. + self.add_submodule('robot_future_encoder/initial_h', + model_if_absent=nn.Linear(self.robot_state_length, + self.hyperparams['enc_rnn_dim_future'])) + self.add_submodule('robot_future_encoder/initial_c', + model_if_absent=nn.Linear(self.robot_state_length, + self.hyperparams['enc_rnn_dim_future'])) + + if self.hyperparams['edge_encoding']: + ############################## + # Edge Influence Encoder # + ############################## + # NOTE: The edge influence encoding happens during calls + # to forward or incremental_forward, so we don't create + # a model for it here for the max and sum variants. + if self.hyperparams['edge_influence_combine_method'] == 'bi-rnn': + self.add_submodule(self.node_type + '/edge_influence_encoder', + model_if_absent=nn.LSTM(input_size=self.hyperparams['enc_rnn_dim_edge'], + hidden_size=self.hyperparams['enc_rnn_dim_edge_influence'], + bidirectional=True, + batch_first=True)) + + # Four times because we're trying to mimic a bi-directional + # LSTM's output (which, here, is c and h from both ends). + self.eie_output_dims = 4 * self.hyperparams['enc_rnn_dim_edge_influence'] + + elif self.hyperparams['edge_influence_combine_method'] == 'attention': + # Chose additive attention because of https://arxiv.org/pdf/1703.03906.pdf + # We calculate an attention context vector using the encoded edges as the "encoder" + # (that we attend _over_) + # and the node history encoder representation as the "decoder state" (that we attend _on_). + self.add_submodule(self.node_type + '/edge_influence_encoder', + model_if_absent=AdditiveAttention( + encoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_edge_influence'], + decoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_history'])) + + self.eie_output_dims = self.hyperparams['enc_rnn_dim_edge_influence'] + + ################### + # Map Encoder # + ################### + if self.hyperparams['use_map_encoding']: + if self.node_type in self.hyperparams['map_encoder']: + me_params = self.hyperparams['map_encoder'][self.node_type] + self.add_submodule(self.node_type + '/map_encoder', + model_if_absent=CNNMapEncoder(me_params['map_channels'], + me_params['hidden_channels'], + me_params['output_size'], + me_params['masks'], + me_params['strides'], + me_params['patch_size'])) + + ################################ + # Discrete Latent Variable # + ################################ + self.latent = DiscreteLatent(self.hyperparams, self.device) + + ###################################################################### + # Various Fully-Connected Layers from Encoder to Latent Variable # + ###################################################################### + # Node History Encoder + x_size = self.hyperparams['enc_rnn_dim_history'] + if self.hyperparams['edge_encoding']: + # Edge Encoder + x_size += self.eie_output_dims + if self.hyperparams['incl_robot_node']: + # Future Conditional Encoder + x_size += 4 * self.hyperparams['enc_rnn_dim_future'] + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + # Map Encoder + x_size += self.hyperparams['map_encoder'][self.node_type]['output_size'] + + z_size = self.hyperparams['N'] * self.hyperparams['K'] + + if self.hyperparams['p_z_x_MLP_dims'] is not None: + self.add_submodule(self.node_type + '/p_z_x', + model_if_absent=nn.Linear(x_size, self.hyperparams['p_z_x_MLP_dims'])) + hx_size = self.hyperparams['p_z_x_MLP_dims'] + else: + hx_size = x_size + + self.add_submodule(self.node_type + '/hx_to_z', + model_if_absent=nn.Linear(hx_size, self.latent.z_dim)) + + if self.hyperparams['q_z_xy_MLP_dims'] is not None: + self.add_submodule(self.node_type + '/q_z_xy', + # Node Future Encoder + model_if_absent=nn.Linear(x_size + 4 * self.hyperparams['enc_rnn_dim_future'], + self.hyperparams['q_z_xy_MLP_dims'])) + hxy_size = self.hyperparams['q_z_xy_MLP_dims'] + else: + # Node Future Encoder + hxy_size = x_size + 4 * self.hyperparams['enc_rnn_dim_future'] + + self.add_submodule(self.node_type + '/hxy_to_z', + model_if_absent=nn.Linear(hxy_size, self.latent.z_dim)) + + #################### + # Decoder LSTM # + #################### + if self.hyperparams['incl_robot_node']: + decoder_input_dims = self.pred_state_length + self.robot_state_length + z_size + x_size + else: + decoder_input_dims = self.pred_state_length + z_size + x_size + + self.add_submodule(self.node_type + '/decoder/state_action', + model_if_absent=nn.Sequential( + nn.Linear(self.state_length, self.pred_state_length))) + + self.add_submodule(self.node_type + '/decoder/rnn_cell', + model_if_absent=nn.GRUCell(decoder_input_dims, self.hyperparams['dec_rnn_dim'])) + self.add_submodule(self.node_type + '/decoder/initial_h', + model_if_absent=nn.Linear(z_size + x_size, self.hyperparams['dec_rnn_dim'])) + + ################### + # Decoder GMM # + ################### + self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_pis', + model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'], + self.hyperparams['GMM_components'])) + self.add_submodule(self.node_type + '/decoder/proj_to_GMM_mus', + model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'], + self.hyperparams['GMM_components'] * self.pred_state_length)) + self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_sigmas', + model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'], + self.hyperparams['GMM_components'] * self.pred_state_length)) + self.add_submodule(self.node_type + '/decoder/proj_to_GMM_corrs', + model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'], + self.hyperparams['GMM_components'])) + + self.x_size = x_size + self.z_size = z_size + + def create_edge_models(self, edge_types): + for edge_type in edge_types: + neighbor_state_length = int( + np.sum([len(entity_dims) for entity_dims in self.state[edge_type.split('->')[1]].values()])) + if self.hyperparams['edge_state_combine_method'] == 'pointnet': + self.add_submodule(edge_type + '/pointnet_encoder', + model_if_absent=nn.Sequential( + nn.Linear(self.state_length, 2 * self.state_length), + nn.ReLU(), + nn.Linear(2 * self.state_length, 2 * self.state_length), + nn.ReLU())) + + edge_encoder_input_size = 2 * self.state_length + self.state_length + + elif self.hyperparams['edge_state_combine_method'] == 'attention': + self.add_submodule(self.node_type + '/edge_attention_combine', + model_if_absent=TemporallyBatchedAdditiveAttention( + encoder_hidden_state_dim=self.state_length, + decoder_hidden_state_dim=self.state_length)) + edge_encoder_input_size = self.state_length + neighbor_state_length + + else: + edge_encoder_input_size = self.state_length + neighbor_state_length + + self.add_submodule(edge_type + '/edge_encoder', + model_if_absent=nn.LSTM(input_size=edge_encoder_input_size, + hidden_size=self.hyperparams['enc_rnn_dim_edge'], + batch_first=True)) + + def create_graphical_model(self, edge_types): + """ + Creates or queries all trainable components. + + :param edge_types: List containing strings for all possible edge types for the node type. + :return: None + """ + self.clear_submodules() + + ############################ + # Everything but Edges # + ############################ + self.create_node_models() + + ##################### + # Edge Encoders # + ##################### + if self.hyperparams['edge_encoding']: + self.create_edge_models(edge_types) + + for name, module in self.node_modules.items(): + module.to(self.device) + + def create_new_scheduler(self, name, annealer, annealer_kws, creation_condition=True): + value_scheduler = None + rsetattr(self, name + '_scheduler', value_scheduler) + if creation_condition: + annealer_kws['device'] = self.device + value_annealer = annealer(annealer_kws) + rsetattr(self, name + '_annealer', value_annealer) + + # This is the value that we'll update on each call of + # step_annealers(). + rsetattr(self, name, value_annealer(0).clone().detach()) + dummy_optimizer = optim.Optimizer([rgetattr(self, name)], {'lr': value_annealer(0).clone().detach()}) + rsetattr(self, name + '_optimizer', dummy_optimizer) + + value_scheduler = CustomLR(dummy_optimizer, + value_annealer) + rsetattr(self, name + '_scheduler', value_scheduler) + + self.schedulers.append(value_scheduler) + self.annealed_vars.append(name) + + def set_annealing_params(self): + self.schedulers = list() + self.annealed_vars = list() + + self.create_new_scheduler(name='kl_weight', + annealer=sigmoid_anneal, + annealer_kws={ + 'start': self.hyperparams['kl_weight_start'], + 'finish': self.hyperparams['kl_weight'], + 'center_step': self.hyperparams['kl_crossover'], + 'steps_lo_to_hi': self.hyperparams['kl_crossover'] / self.hyperparams[ + 'kl_sigmoid_divisor'] + }) + + self.create_new_scheduler(name='latent.temp', + annealer=exp_anneal, + annealer_kws={ + 'start': self.hyperparams['tau_init'], + 'finish': self.hyperparams['tau_final'], + 'rate': self.hyperparams['tau_decay_rate'] + }) + + self.create_new_scheduler(name='latent.z_logit_clip', + annealer=sigmoid_anneal, + annealer_kws={ + 'start': self.hyperparams['z_logit_clip_start'], + 'finish': self.hyperparams['z_logit_clip_final'], + 'center_step': self.hyperparams['z_logit_clip_crossover'], + 'steps_lo_to_hi': self.hyperparams['z_logit_clip_crossover'] / self.hyperparams[ + 'z_logit_clip_divisor'] + }, + creation_condition=self.hyperparams['use_z_logit_clipping']) + + def step_annealers(self): + # This should manage all of the step-wise changed + # parameters automatically. + for idx, annealed_var in enumerate(self.annealed_vars): + if rgetattr(self, annealed_var + '_scheduler') is not None: + # First we step the scheduler. + with warnings.catch_warnings(): # We use a dummy optimizer: Warning because no .step() was called on it + warnings.simplefilter("ignore") + rgetattr(self, annealed_var + '_scheduler').step() + + # Then we set the annealed vars' value. + rsetattr(self, annealed_var, rgetattr(self, annealed_var + '_optimizer').param_groups[0]['lr']) + + self.summarize_annealers() + + def summarize_annealers(self): + if self.log_writer is not None: + for annealed_var in self.annealed_vars: + if rgetattr(self, annealed_var) is not None: + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), annealed_var.replace('.', '/')), + rgetattr(self, annealed_var), self.curr_iter) + + def obtain_encoded_tensors(self, + mode, + inputs, + inputs_st, + labels, + labels_st, + first_history_indices, + neighbors, + neighbors_edge_value, + robot, + map) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encodes input and output tensors for node and robot. + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param inputs: Input tensor including the state for each agent over time [bs, t, state]. + :param inputs_st: Standardized input tensor. + :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state]. + :param labels_st: Standardized label tensor. + :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs] + :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time. + [[bs, t, neighbor state]] + :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]] + :param robot: Standardized robot state over time. [bs, t, robot_state] + :param map: Tensor of Map information. [bs, channels, x, y] + :return: tuple(x, x_nr_t, y_e, y_r, y, n_s_t0) + WHERE + - x: Encoded input / condition tensor to the CVAE x_e. + - x_r_t: Robot state (if robot is in scene). + - y_e: Encoded label / future of the node. + - y_r: Encoded future of the robot. + - y: Label / future of the node. + - n_s_t0: Standardized current state of the node. + """ + + x, x_r_t, y_e, y_r, y = None, None, None, None, None + initial_dynamics = dict() + + batch_size = inputs.shape[0] + + ######################################### + # Provide basic information to encoders # + ######################################### + node_history = inputs + node_present_state = inputs[:, -1] + node_pos = inputs[:, -1, 0:2] + node_vel = inputs[:, -1, 2:4] + + node_history_st = inputs_st + node_present_state_st = inputs_st[:, -1] + node_pos_st = inputs_st[:, -1, 0:2] + node_vel_st = inputs_st[:, -1, 2:4] + + n_s_t0 = node_present_state_st + + initial_dynamics['pos'] = node_pos + initial_dynamics['vel'] = node_vel + + self.dynamic.set_initial_condition(initial_dynamics) + + if self.hyperparams['incl_robot_node']: + x_r_t, y_r = robot[..., 0, :], robot[..., 1:, :] + + ################## + # Encode History # + ################## + node_history_encoded = self.encode_node_history(mode, + node_history_st, + first_history_indices) + + ################## + # Encode Present # + ################## + node_present = node_present_state_st # [bs, state_dim] + + ################## + # Encode Future # + ################## + if mode != ModeKeys.PREDICT: + y = labels_st + + ############################## + # Encode Node Edges per Type # + ############################## + if self.hyperparams['edge_encoding']: + node_edges_encoded = list() + for edge_type in self.edge_types: + # Encode edges for given edge type + encoded_edges_type = self.encode_edge(mode, + node_history, + node_history_st, + edge_type, + neighbors[edge_type], + neighbors_edge_value[edge_type], + first_history_indices) + node_edges_encoded.append(encoded_edges_type) # List of [bs/nbs, enc_rnn_dim] + ##################### + # Encode Node Edges # + ##################### + total_edge_influence = self.encode_total_edge_influence(mode, + node_edges_encoded, + node_history_encoded, + batch_size) + + ################ + # Map Encoding # + ################ + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + if self.log_writer and (self.curr_iter + 1) % 500 == 0: + map_clone = map.clone() + map_patch = self.hyperparams['map_encoder'][self.node_type]['patch_size'] + map_clone[:, :, map_patch[1] - 5:map_patch[1] + 5, map_patch[0] - 5:map_patch[0] + 5] = 1. + self.log_writer.add_images(f"{self.node_type}/cropped_maps", map_clone, + self.curr_iter, dataformats='NCWH') + + encoded_map = self.node_modules[self.node_type + '/map_encoder'](map * 2. - 1., (mode == ModeKeys.TRAIN)) + do = self.hyperparams['map_encoder'][self.node_type]['dropout'] + encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN)) + + ###################################### + # Concatenate Encoder Outputs into x # + ###################################### + x_concat_list = list() + + # Every node has an edge-influence encoder (which could just be zero). + if self.hyperparams['edge_encoding']: + x_concat_list.append(total_edge_influence) # [bs/nbs, 4*enc_rnn_dim] + + # Every node has a history encoder. + x_concat_list.append(node_history_encoded) # [bs/nbs, enc_rnn_dim_history] + + if self.hyperparams['incl_robot_node']: + robot_future_encoder = self.encode_robot_future(mode, x_r_t, y_r) + x_concat_list.append(robot_future_encoder) + + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + if self.log_writer: + self.log_writer.add_scalar(f"{self.node_type}/encoded_map_max", + torch.max(torch.abs(encoded_map)), self.curr_iter) + x_concat_list.append(encoded_map) + + x = torch.cat(x_concat_list, dim=1) + + if mode == ModeKeys.TRAIN or mode == ModeKeys.EVAL: + y_e = self.encode_node_future(mode, node_present, y) + + return x, x_r_t, y_e, y_r, y, n_s_t0 + + def encode_node_history(self, mode, node_hist, first_history_indices): + """ + Encodes the nodes history. + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param node_hist: Historic and current state of the node. [bs, mhl, state] + :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs] + :return: Encoded node history tensor. [bs, enc_rnn_dim] + """ + outputs, _ = run_lstm_on_variable_length_seqs(self.node_modules[self.node_type + '/node_history_encoder'], + original_seqs=node_hist, + lower_indices=first_history_indices) + + outputs = F.dropout(outputs, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) # [bs, max_time, enc_rnn_dim] + + last_index_per_sequence = -(first_history_indices + 1) + + return outputs[torch.arange(first_history_indices.shape[0]), last_index_per_sequence] + + def encode_edge(self, + mode, + node_history, + node_history_st, + edge_type, + neighbors, + neighbors_edge_value, + first_history_indices): + + max_hl = self.hyperparams['maximum_history_length'] + + edge_states_list = list() # list of [#of neighbors, max_ht, state_dim] + for i, neighbor_states in enumerate(neighbors): # Get neighbors for timestep in batch + if len(neighbor_states) == 0: # There are no neighbors for edge type # TODO necessary? + neighbor_state_length = self.node_type_state_lengths[edge_type[1]] + edge_states_list.append(torch.zeros((1, max_hl + 1, neighbor_state_length), device=self.device)) + else: + edge_states_list.append(torch.stack(neighbor_states, dim=0).to(self.device)) + + if self.hyperparams['edge_state_combine_method'] == 'sum': + # Used in Structural-RNN to combine edges as well. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_value in neighbors_edge_value: + op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_value.to(self.device), + dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + elif self.hyperparams['edge_state_combine_method'] == 'max': + # Used in NLP, e.g. max over word embeddings in a sentence. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_value in neighbors_edge_value: + op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_value.to(self.device), + dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + elif self.hyperparams['edge_state_combine_method'] == 'mean': + # Used in NLP, e.g. mean over word embeddings in a sentence. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_value in neighbors_edge_value: + op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_value.to(self.device), + dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + joint_history = torch.cat([combined_neighbors, node_history_st], dim=-1) + + outputs, _ = run_lstm_on_variable_length_seqs( + self.node_modules[DirectedEdge.get_str_from_types(*edge_type) + '/edge_encoder'], + original_seqs=joint_history, + lower_indices=first_history_indices + ) + + outputs = F.dropout(outputs, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) # [bs, max_time, enc_rnn_dim] + + last_index_per_sequence = -(first_history_indices + 1) + ret = outputs[torch.arange(last_index_per_sequence.shape[0]), last_index_per_sequence] + if self.hyperparams['dynamic_edges'] == 'yes': + return ret * combined_edge_masks + else: + return ret + + def encode_total_edge_influence(self, mode, encoded_edges, node_history_encoder, batch_size): + if self.hyperparams['edge_influence_combine_method'] == 'sum': + stacked_encoded_edges = torch.stack(encoded_edges, dim=0) + combined_edges = torch.sum(stacked_encoded_edges, dim=0) + + elif self.hyperparams['edge_influence_combine_method'] == 'mean': + stacked_encoded_edges = torch.stack(encoded_edges, dim=0) + combined_edges = torch.mean(stacked_encoded_edges, dim=0) + + elif self.hyperparams['edge_influence_combine_method'] == 'max': + stacked_encoded_edges = torch.stack(encoded_edges, dim=0) + combined_edges = torch.max(stacked_encoded_edges, dim=0) + + elif self.hyperparams['edge_influence_combine_method'] == 'bi-rnn': + if len(encoded_edges) == 0: + combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device) + + else: + # axis=1 because then we get size [batch_size, max_time, depth] + encoded_edges = torch.stack(encoded_edges, dim=1) + + _, state = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges) + combined_edges = unpack_RNN_state(state) + combined_edges = F.dropout(combined_edges, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + elif self.hyperparams['edge_influence_combine_method'] == 'attention': + # Used in Social Attention (https://arxiv.org/abs/1710.04689) + if len(encoded_edges) == 0: + combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device) + + else: + # axis=1 because then we get size [batch_size, max_time, depth] + encoded_edges = torch.stack(encoded_edges, dim=1) + combined_edges, _ = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges, + node_history_encoder) + combined_edges = F.dropout(combined_edges, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + return combined_edges + + def encode_node_future(self, mode, node_present, node_future) -> torch.Tensor: + """ + Encodes the node future (during training) using a bi-directional LSTM + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param node_present: Current state of the node. [bs, state] + :param node_future: Future states of the node. [bs, ph, state] + :return: Encoded future. + """ + initial_h_model = self.node_modules[self.node_type + '/node_future_encoder/initial_h'] + initial_c_model = self.node_modules[self.node_type + '/node_future_encoder/initial_c'] + + # Here we're initializing the forward hidden states, + # but zeroing the backward ones. + initial_h = initial_h_model(node_present) + initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0) + + initial_c = initial_c_model(node_present) + initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0) + + initial_state = (initial_h, initial_c) + + _, state = self.node_modules[self.node_type + '/node_future_encoder'](node_future, initial_state) + state = unpack_RNN_state(state) + state = F.dropout(state, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + return state + + def encode_robot_future(self, mode, robot_present, robot_future) -> torch.Tensor: + """ + Encodes the robot future (during training) using a bi-directional LSTM + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param robot_present: Current state of the robot. [bs, state] + :param robot_future: Future states of the robot. [bs, ph, state] + :return: Encoded future. + """ + initial_h_model = self.node_modules['robot_future_encoder/initial_h'] + initial_c_model = self.node_modules['robot_future_encoder/initial_c'] + + # Here we're initializing the forward hidden states, + # but zeroing the backward ones. + initial_h = initial_h_model(robot_present) + initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0) + + initial_c = initial_c_model(robot_present) + initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0) + + initial_state = (initial_h, initial_c) + + _, state = self.node_modules['robot_future_encoder'](robot_future, initial_state) + state = unpack_RNN_state(state) + state = F.dropout(state, + p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + return state + + def q_z_xy(self, mode, x, y_e) -> torch.Tensor: + r""" + .. math:: q_\phi(z \mid \mathbf{x}_i, \mathbf{y}_i) + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param x: Input / Condition tensor. + :param y_e: Encoded future tensor. + :return: Latent distribution of the CVAE. + """ + xy = torch.cat([x, y_e], dim=1) + + if self.hyperparams['q_z_xy_MLP_dims'] is not None: + dense = self.node_modules[self.node_type + '/q_z_xy'] + h = F.dropout(F.relu(dense(xy)), + p=1. - self.hyperparams['MLP_dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + else: + h = xy + + to_latent = self.node_modules[self.node_type + '/hxy_to_z'] + return self.latent.dist_from_h(to_latent(h), mode) + + def p_z_x(self, mode, x): + r""" + .. math:: p_\theta(z \mid \mathbf{x}_i) + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param x: Input / Condition tensor. + :return: Latent distribution of the CVAE. + """ + if self.hyperparams['p_z_x_MLP_dims'] is not None: + dense = self.node_modules[self.node_type + '/p_z_x'] + h = F.dropout(F.relu(dense(x)), + p=1. - self.hyperparams['MLP_dropout_keep_prob'], + training=(mode == ModeKeys.TRAIN)) + + else: + h = x + + to_latent = self.node_modules[self.node_type + '/hx_to_z'] + return self.latent.dist_from_h(to_latent(h), mode) + + def project_to_GMM_params(self, tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Projects tensor to parameters of a GMM with N components and D dimensions. + + :param tensor: Input tensor. + :return: tuple(log_pis, mus, log_sigmas, corrs) + WHERE + - log_pis: Weight (logarithm) of each GMM component. [N] + - mus: Mean of each GMM component. [N, D] + - log_sigmas: Standard Deviation (logarithm) of each GMM component. [N, D] + - corrs: Correlation between the GMM components. [N] + """ + log_pis = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_pis'](tensor) + mus = self.node_modules[self.node_type + '/decoder/proj_to_GMM_mus'](tensor) + log_sigmas = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_sigmas'](tensor) + corrs = torch.tanh(self.node_modules[self.node_type + '/decoder/proj_to_GMM_corrs'](tensor)) + return log_pis, mus, log_sigmas, corrs + + def p_y_xz(self, mode, x, x_nr_t, y_r, n_s_t0, z_stacked, prediction_horizon, + num_samples, num_components=1, gmm_mode=False): + r""" + .. math:: p_\psi(\mathbf{y}_i \mid \mathbf{x}_i, z) + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param x: Input / Condition tensor. + :param x_nr_t: Joint state of node and robot (if robot is in scene). + :param y: Future tensor. + :param y_r: Encoded future tensor. + :param n_s_t0: Standardized current state of the node. + :param z_stacked: Stacked latent state. [num_samples_z * num_samples_gmm, bs, latent_state] + :param prediction_horizon: Number of prediction timesteps. + :param num_samples: Number of samples from the latent space. + :param num_components: Number of GMM components. + :param gmm_mode: If True: The mode of the GMM is sampled. + :return: GMM2D. If mode is Predict, also samples from the GMM. + """ + ph = prediction_horizon + pred_dim = self.pred_state_length + + z = torch.reshape(z_stacked, (-1, self.latent.z_dim)) + zx = torch.cat([z, x.repeat(num_samples * num_components, 1)], dim=1) + + cell = self.node_modules[self.node_type + '/decoder/rnn_cell'] + initial_h_model = self.node_modules[self.node_type + '/decoder/initial_h'] + + initial_state = initial_h_model(zx) + + log_pis, mus, log_sigmas, corrs, a_sample = [], [], [], [], [] + + # Infer initial action state for node from current state + a_0 = self.node_modules[self.node_type + '/decoder/state_action'](n_s_t0) + + state = initial_state + if self.hyperparams['incl_robot_node']: + input_ = torch.cat([zx, + a_0.repeat(num_samples * num_components, 1), + x_nr_t.repeat(num_samples * num_components, 1)], dim=1) + else: + input_ = torch.cat([zx, a_0.repeat(num_samples * num_components, 1)], dim=1) + + for j in range(ph): + h_state = cell(input_, state) + log_pi_t, mu_t, log_sigma_t, corr_t = self.project_to_GMM_params(h_state) + + gmm = GMM2D(log_pi_t, mu_t, log_sigma_t, corr_t) # [k;bs, pred_dim] + + if mode == ModeKeys.PREDICT and gmm_mode: + a_t = gmm.mode() + else: + a_t = gmm.rsample() + + if num_components > 1: + if mode == ModeKeys.PREDICT: + log_pis.append(self.latent.p_dist.logits.repeat(num_samples, 1, 1)) + else: + log_pis.append(self.latent.q_dist.logits.repeat(num_samples, 1, 1)) + else: + log_pis.append( + torch.ones_like(corr_t.reshape(num_samples, num_components, -1).permute(0, 2, 1).reshape(-1, 1)) + ) + + mus.append( + mu_t.reshape( + num_samples, num_components, -1, 2 + ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components) + ) + log_sigmas.append( + log_sigma_t.reshape( + num_samples, num_components, -1, 2 + ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components)) + corrs.append( + corr_t.reshape( + num_samples, num_components, -1 + ).permute(0, 2, 1).reshape(-1, num_components)) + + if self.hyperparams['incl_robot_node']: + dec_inputs = [zx, a_t, y_r[:, j].repeat(num_samples * num_components, 1)] + else: + dec_inputs = [zx, a_t] + input_ = torch.cat(dec_inputs, dim=1) + state = h_state + + log_pis = torch.stack(log_pis, dim=1) + mus = torch.stack(mus, dim=1) + log_sigmas = torch.stack(log_sigmas, dim=1) + corrs = torch.stack(corrs, dim=1) + + a_dist = GMM2D(torch.reshape(log_pis, [num_samples, -1, ph, num_components]), + torch.reshape(mus, [num_samples, -1, ph, num_components * pred_dim]), + torch.reshape(log_sigmas, [num_samples, -1, ph, num_components * pred_dim]), + torch.reshape(corrs, [num_samples, -1, ph, num_components])) + + if self.hyperparams['dynamic'][self.node_type]['distribution']: + y_dist = self.dynamic.integrate_distribution(a_dist, x) + else: + y_dist = a_dist + + if mode == ModeKeys.PREDICT: + if gmm_mode: + a_sample = a_dist.mode() + else: + a_sample = a_dist.rsample() + sampled_future = self.dynamic.integrate_samples(a_sample, x) + return y_dist, sampled_future + else: + return y_dist + + def encoder(self, mode, x, y_e, num_samples=None): + """ + Encoder of the CVAE. + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param x: Input / Condition tensor. + :param y_e: Encoded future tensor. + :param num_samples: Number of samples from the latent space during Prediction. + :return: tuple(z, kl_obj) + WHERE + - z: Samples from the latent space. + - kl_obj: KL Divergenze between q and p + """ + if mode == ModeKeys.TRAIN: + sample_ct = self.hyperparams['k'] + elif mode == ModeKeys.EVAL: + sample_ct = self.hyperparams['k_eval'] + elif mode == ModeKeys.PREDICT: + sample_ct = num_samples + if num_samples is None: + raise ValueError("num_samples cannot be None with mode == PREDICT.") + + self.latent.q_dist = self.q_z_xy(mode, x, y_e) + self.latent.p_dist = self.p_z_x(mode, x) + + z = self.latent.sample_q(sample_ct, mode) + + if mode == ModeKeys.TRAIN: + kl_obj = self.latent.kl_q_p(self.log_writer, '%s' % str(self.node_type), self.curr_iter) + if self.log_writer is not None: + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'kl'), kl_obj, self.curr_iter) + else: + kl_obj = None + + return z, kl_obj + + def decoder(self, mode, x, x_nr_t, y, y_r, n_s_t0, z, labels, prediction_horizon, num_samples, ret_dist=False): + """ + Decoder of the CVAE. + + :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict. + :param x: Input / Condition tensor. + :param x: Input / Condition tensor. + :param x_nr_t: Joint state of node and robot (if robot is in scene). + :param y: Future tensor. + :param y_r: Encoded future tensor. + :param n_s_t0: Standardized current state of the node. + :param z: Stacked latent state. + :param prediction_horizon: Number of prediction timesteps. + :param num_samples: Number of samples from the latent space. + :return: Log probability of y over p. + """ + + num_components = self.hyperparams['N'] * self.hyperparams['K'] + y_dist = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z, + prediction_horizon, num_samples, num_components=num_components) + log_p_yt_xz = torch.clamp(y_dist.log_prob(labels), max=self.hyperparams['log_p_yt_xz_max']) + if self.hyperparams['log_histograms'] and self.log_writer is not None: + self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_yt_xz'), log_p_yt_xz, self.curr_iter) + + log_p_y_xz = torch.sum(log_p_yt_xz, dim=2) + if ret_dist: + return log_p_y_xz, y_dist + else: + return log_p_y_xz + + def forward(self, **kwargs): + return self.train_loss(**kwargs) + + def train_loss(self, + inputs, + inputs_st, + first_history_indices, + labels, + labels_st, + neighbors, + neighbors_edge_value, + robot, + map, + prediction_horizon, + ret_dist=False, + loss_weights=None, + ) -> torch.Tensor: + """ + Calculates the training loss for a batch. + + :param inputs: Input tensor including the state for each agent over time [bs, t, state]. + :param inputs_st: Standardized input tensor. + :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs] + :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state]. + :param labels_st: Standardized label tensor. + :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time. + [[bs, t, neighbor state]] + :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]] + :param robot: Standardized robot state over time. [bs, t, robot_state] + :param map: Tensor of Map information. [bs, channels, x, y] + :param prediction_horizon: Number of prediction timesteps. + :return: Scalar tensor -> nll loss + """ + mode = ModeKeys.TRAIN + + x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode, + inputs=inputs, + inputs_st=inputs_st, + labels=labels, + labels_st=labels_st, + first_history_indices=first_history_indices, + neighbors=neighbors, + neighbors_edge_value=neighbors_edge_value, + robot=robot, + map=map) + + z, kl = self.encoder(mode, x, y_e) + log_p_y_xz, y_dist = self.decoder(mode, x, x_nr_t, y, y_r, n_s_t0, z, + labels, # Loss is calculated on unstandardized label + prediction_horizon, + self.hyperparams['k'], + ret_dist=True) + + log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0) # [nbs] + if loss_weights is not None: + # Weighted sum over batch. Weights are expected to sum to 1. + log_likelihood = torch.sum(log_p_y_xz_mean * loss_weights, dim=0) + else: + log_likelihood = torch.mean(log_p_y_xz_mean) + + mutual_inf_q = mutual_inf_mc(self.latent.q_dist) + mutual_inf_p = mutual_inf_mc(self.latent.p_dist) + + ELBO = log_likelihood - self.kl_weight * kl + 1. * mutual_inf_p + loss = -ELBO + + if self.hyperparams['log_histograms'] and self.log_writer is not None: + self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_y_xz'), + log_p_y_xz_mean, + self.curr_iter) + + if self.log_writer is not None: + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_q'), + mutual_inf_q, + self.curr_iter) + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_p'), + mutual_inf_p, + self.curr_iter) + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'log_likelihood'), + log_likelihood, + self.curr_iter) + self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'loss'), + loss, + self.curr_iter) + if self.hyperparams['log_histograms']: + self.latent.summarize_for_tensorboard(self.log_writer, str(self.node_type), self.curr_iter) + + if ret_dist: + return loss, y_dist, (x, ) + else: + return loss + + def eval_loss(self, + inputs, + inputs_st, + first_history_indices, + labels, + labels_st, + neighbors, + neighbors_edge_value, + robot, + map, + prediction_horizon) -> torch.Tensor: + """ + Calculates the evaluation loss for a batch. + + :param inputs: Input tensor including the state for each agent over time [bs, t, state]. + :param inputs_st: Standardized input tensor. + :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs] + :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state]. + :param labels_st: Standardized label tensor. + :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time. + [[bs, t, neighbor state]] + :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]] + :param robot: Standardized robot state over time. [bs, t, robot_state] + :param map: Tensor of Map information. [bs, channels, x, y] + :param prediction_horizon: Number of prediction timesteps. + :return: tuple(nll_q_is, nll_p, nll_exact, nll_sampled) + """ + + mode = ModeKeys.EVAL + + x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode, + inputs=inputs, + inputs_st=inputs_st, + labels=labels, + labels_st=labels_st, + first_history_indices=first_history_indices, + neighbors=neighbors, + neighbors_edge_value=neighbors_edge_value, + robot=robot, + map=map) + + num_components = self.hyperparams['N'] * self.hyperparams['K'] + ### Importance sampled NLL estimate + z, _ = self.encoder(mode, x, y_e) # [k_eval, nbs, N*K] + z = self.latent.sample_p(1, mode, full_dist=True) + y_dist, _ = self.p_y_xz(ModeKeys.PREDICT, x, x_nr_t, y_r, n_s_t0, z, + prediction_horizon, num_samples=1, num_components=num_components) + # We use unstandardized labels to compute the loss + log_p_yt_xz = torch.clamp(y_dist.log_prob(labels), max=self.hyperparams['log_p_yt_xz_max']) + log_p_y_xz = torch.sum(log_p_yt_xz, dim=2) + log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0) # [nbs] + log_likelihood = torch.mean(log_p_y_xz_mean) + nll = -log_likelihood + + return nll + + def predict(self, + inputs, + inputs_st, + first_history_indices, + neighbors, + neighbors_edge_value, + robot, + map, + prediction_horizon, + num_samples, + z_mode=False, + gmm_mode=False, + full_dist=True, + all_z_sep=False, + output_dists=False, + output_extra=False, + ): + """ + Predicts the future of a batch of nodes. + + :param inputs: Input tensor including the state for each agent over time [bs, t, state]. + :param inputs_st: Standardized input tensor. + :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs] + :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time. + [[bs, t, neighbor state]] + :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]] + :param robot: Standardized robot state over time. [bs, t, robot_state] + :param map: Tensor of Map information. [bs, channels, x, y] + :param prediction_horizon: Number of prediction timesteps. + :param num_samples: Number of samples from the latent space. + :param z_mode: If True: Select the most likely latent state. + :param gmm_mode: If True: The mode of the GMM is sampled. + :param all_z_sep: Samples each latent mode individually without merging them into a GMM. + :param full_dist: Samples all latent states and merges them into a GMM as output. + :return: + """ + mode = ModeKeys.PREDICT + + x, x_nr_t, _, y_r, _, n_s_t0 = self.obtain_encoded_tensors(mode=mode, + inputs=inputs, + inputs_st=inputs_st, + labels=None, + labels_st=None, + first_history_indices=first_history_indices, + neighbors=neighbors, + neighbors_edge_value=neighbors_edge_value, + robot=robot, + map=map) + + self.latent.p_dist = self.p_z_x(mode, x) + z, num_samples, num_components = self.latent.sample_p(num_samples, + mode, + most_likely_z=z_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) + + y_dist, our_sampled_future = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z, + prediction_horizon, + num_samples, + num_components, + gmm_mode) + if output_extra: + return y_dist, our_sampled_future, (x, ) + if output_dists: + return y_dist, our_sampled_future + else: + return our_sampled_future diff --git a/diffstack/modules/predictors/trajectron_utils/model/model_registrar.py b/diffstack/modules/predictors/trajectron_utils/model/model_registrar.py new file mode 100644 index 0000000..6446f9d --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/model_registrar.py @@ -0,0 +1,70 @@ +import os +import torch +import torch.nn as nn + + +def get_model_device(model): + return next(model.parameters()).device + + +class ModelRegistrar(nn.Module): + def __init__(self, model_dir, device): + super(ModelRegistrar, self).__init__() + self.model_dict = nn.ModuleDict() + self.model_dir = model_dir + self.device = device + + def forward(self): + raise NotImplementedError('Although ModelRegistrar is a nn.Module, it is only to store parameters.') + + def get_model(self, name, model_if_absent=None): + # 4 cases: name in self.model_dict and model_if_absent is None (OK) + # name in self.model_dict and model_if_absent is not None (OK) + # name not in self.model_dict and model_if_absent is not None (OK) + # name not in self.model_dict and model_if_absent is None (NOT OK) + + if name in self.model_dict: + return self.model_dict[name] + + elif model_if_absent is not None: + self.model_dict[name] = model_if_absent.to(self.device) + return self.model_dict[name] + + else: + raise ValueError(f'{name} was never initialized in this Registrar!') + + def get_name_match(self, name): + ret_model_list = nn.ModuleList() + for key in self.model_dict.keys(): + if name in key: + ret_model_list.append(self.model_dict[key]) + return ret_model_list + + def get_all_but_name_match(self, name): + ret_model_list = nn.ModuleList() + for key in self.model_dict.keys(): + if name not in key: + ret_model_list.append(self.model_dict[key]) + return ret_model_list + + def print_model_names(self): + print(self.model_dict.keys()) + + def save_models(self, curr_iter): + # Create the model directiory if it's not present. + save_path = os.path.join(self.model_dir, + 'model_registrar-%d.pt' % curr_iter) + + torch.save(self.model_dict, save_path) + + def load_models(self, iter_num): + self.model_dict.clear() + + save_path = os.path.join(self.model_dir, + 'model_registrar-%d.pt' % iter_num) + + print('') + print('Loading from ' + save_path) + self.model_dict = torch.load(save_path, map_location=self.device) + print('Loaded!') + print('') diff --git a/diffstack/modules/predictors/trajectron_utils/model/model_utils.py b/diffstack/modules/predictors/trajectron_utils/model/model_utils.py new file mode 100644 index 0000000..0d516b3 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/model_utils.py @@ -0,0 +1,125 @@ +import torch +import torch.nn.utils.rnn as rnn +from enum import Enum +import functools +import numpy as np +import math + + +class ModeKeys(Enum): + TRAIN = 1 + EVAL = 2 + PREDICT = 3 + + +def cyclical_lr(stepsize, min_lr=3e-4, max_lr=3e-3, decay=1.): + # Lambda function to calculate the LR + lr_lambda = lambda it: min_lr + (max_lr - min_lr) * relative(it, stepsize) * decay**it + + # Additional function to see where on the cycle we are + def relative(it, stepsize): + cycle = math.floor(1 + it / (2 * stepsize)) + x = abs(it / stepsize - 2 * cycle + 1) + return max(0, (1 - x)) + + return lr_lambda + + +def to_one_hot(labels, n_labels): + return torch.eye(n_labels, device=labels.device)[labels] + + +def exp_anneal(anneal_kws): + device = anneal_kws['device'] + start = torch.tensor(anneal_kws['start'], device=device) + finish = torch.tensor(anneal_kws['finish'], device=device) + rate = torch.tensor(anneal_kws['rate'], device=device) + return lambda step: finish - (finish - start)*torch.pow(rate, torch.tensor(step, dtype=torch.float, device=device)) + + +def sigmoid_anneal(anneal_kws): + device = anneal_kws['device'] + start = torch.tensor(anneal_kws['start'], device=device) + finish = torch.tensor(anneal_kws['finish'], device=device) + center_step = torch.tensor(anneal_kws['center_step'], device=device, dtype=torch.float) + steps_lo_to_hi = torch.tensor(anneal_kws['steps_lo_to_hi'], device=device, dtype=torch.float) + return lambda step: start + (finish - start)*torch.sigmoid((torch.tensor(float(step), device=device) - center_step) * (1./steps_lo_to_hi)) + + +class CustomLR(torch.optim.lr_scheduler.LambdaLR): + def __init__(self, optimizer, lr_lambda, last_epoch=-1): + super(CustomLR, self).__init__(optimizer, lr_lambda, last_epoch) + + def get_lr(self): + return [lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] + + +def mutual_inf_mc(x_dist): + dist = x_dist.__class__ + H_y = dist(probs=x_dist.probs.mean(dim=0)).entropy() + return (H_y - x_dist.entropy().mean(dim=0)).sum() + + +def run_lstm_on_variable_length_seqs(lstm_module, original_seqs, lower_indices=None, upper_indices=None, total_length=None): + bs, tf = original_seqs.shape[:2] + if lower_indices is None: + lower_indices = torch.zeros(bs, dtype=torch.int) + if upper_indices is None: + upper_indices = torch.ones(bs, dtype=torch.int) * (tf - 1) + if total_length is None: + total_length = max(upper_indices) + 1 + # This is done so that we can just pass in self.prediction_timesteps + # (which we want to INCLUDE, so this will exclude the next timestep). + inclusive_break_indices = upper_indices + 1 + + pad_list = list() + for i, seq_len in enumerate(inclusive_break_indices): + pad_list.append(original_seqs[i, lower_indices[i]:seq_len]) + + packed_seqs = rnn.pack_sequence(pad_list, enforce_sorted=False) + packed_output, (h_n, c_n) = lstm_module(packed_seqs) + output, _ = rnn.pad_packed_sequence(packed_output, + batch_first=True, + total_length=total_length) + + return output, (h_n, c_n) + + +def extract_subtensor_per_batch_element(tensor, indices): + batch_idxs = torch.arange(start=0, end=len(indices)) + + batch_idxs = batch_idxs[~torch.isnan(indices)] + indices = indices[~torch.isnan(indices)] + if indices.size == 0: + return None + else: + indices = indices.long() + if tensor.is_cuda: + batch_idxs = batch_idxs.to(tensor.get_device()) + indices = indices.to(tensor.get_device()) + return tensor[batch_idxs, indices] + + +def unpack_RNN_state(state_tuple): + # PyTorch returned LSTM states have 3 dims: + # (num_layers * num_directions, batch, hidden_size) + + state = torch.cat(state_tuple, dim=0).permute(1, 0, 2) + # Now state is (batch, 2 * num_layers * num_directions, hidden_size) + + state_size = state.size() + return torch.reshape(state, (-1, state_size[1] * state_size[2])) + + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +# using wonder's beautiful simplification: +# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427 +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + return functools.reduce(_getattr, [obj] + attr.split('.')) diff --git a/diffstack/modules/predictors/trajectron_utils/model/online/__init__.py b/diffstack/modules/predictors/trajectron_utils/model/online/__init__.py new file mode 100644 index 0000000..a1c9070 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/online/__init__.py @@ -0,0 +1,2 @@ +from .online_trajectron import OnlineTrajectron +from .online_mgcvae import OnlineMultimodalGenerativeCVAE diff --git a/diffstack/modules/predictors/trajectron_utils/model/online/online_mgcvae.py b/diffstack/modules/predictors/trajectron_utils/model/online/online_mgcvae.py new file mode 100644 index 0000000..a843244 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/online/online_mgcvae.py @@ -0,0 +1,419 @@ +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from collections import defaultdict, Counter +from model.components import * +from model.model_utils import * +from model.dataset import get_relative_robot_traj +import model.dynamics as dynamic_module +from model.mgcvae import MultimodalGenerativeCVAE +from environment.scene_graph import DirectedEdge +from environment.node_type import NodeType + + +class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): + def __init__(self, + env, + node, + model_registrar, + hyperparams, + device): + if len(env.scenes) != 1: + raise ValueError("Passed in Environment has number of scenes != 1") + + super(OnlineMultimodalGenerativeCVAE, self).__init__(env, node.type, model_registrar, hyperparams, device, edge_types=[]) + + self.node = node + self.robot = env.scenes[0].robot + + self.scene_graph = None + + self.curr_hidden_states = dict() + self.edge_types = Counter() + + self.create_initial_graphical_model() + + def create_initial_graphical_model(self): + """ + Creates or queries all trainable components. + + :return: None + """ + self.clear_submodules() + + ############################ + # Everything but Edges # + ############################ + self.create_node_models() + + for name, module in self.node_modules.items(): + module.to(self.device) + + def update_graph(self, new_scene_graph, new_neighbors, removed_neighbors): + self.scene_graph = new_scene_graph + + if self.node in new_neighbors: + for edge_type, new_neighbor_nodes in new_neighbors[self.node].items(): + self.add_edge_model(edge_type) + self.edge_types += Counter({edge_type: len(new_neighbor_nodes)}) + + if self.node in removed_neighbors: + for edge_type, removed_neighbor_nodes in removed_neighbors[self.node].items(): + self.remove_edge_model(edge_type) + self.edge_types -= Counter({edge_type: len(removed_neighbor_nodes)}) + + def get_edge_to(self, other_node): + return DirectedEdge(self.node, other_node) + + def add_edge_model(self, edge_type): + if self.hyperparams['edge_encoding']: + if edge_type + '/edge_encoder' not in self.node_modules: + neighbor_state_length = int( + np.sum([len(entity_dims) for entity_dims in + self.state[self._get_other_node_type_from_edge(edge_type)].values()])) + if self.hyperparams['edge_state_combine_method'] == 'pointnet': + self.add_submodule(edge_type + '/pointnet_encoder', + model_if_absent=nn.Sequential( + nn.Linear(self.state_length, 2 * self.state_length), + nn.ReLU(), + nn.Linear(2 * self.state_length, 2 * self.state_length), + nn.ReLU())) + + edge_encoder_input_size = 2 * self.state_length + self.state_length + + elif self.hyperparams['edge_state_combine_method'] == 'attention': + self.add_submodule(self.node.type + '/edge_attention_combine', + model_if_absent=TemporallyBatchedAdditiveAttention( + encoder_hidden_state_dim=self.state_length, + decoder_hidden_state_dim=self.state_length)) + edge_encoder_input_size = self.state_length + neighbor_state_length + + else: + edge_encoder_input_size = self.state_length + neighbor_state_length + + self.add_submodule(edge_type + '/edge_encoder', + model_if_absent=nn.LSTM(input_size=edge_encoder_input_size, + hidden_size=self.hyperparams['enc_rnn_dim_edge'], + batch_first=True)) + + def _get_other_node_type_from_edge(self, edge_type_str): + n2_type_str = edge_type_str.split('->')[1] + return NodeType(n2_type_str, self.env.node_type_list.index(n2_type_str) + 1) + + def _get_edge_type_from_str(self, edge_type_str): + n1_type_str, n2_type_str = edge_type_str.split('->') + return (NodeType(n1_type_str, self.env.node_type_list.index(n1_type_str) + 1), + NodeType(n2_type_str, self.env.node_type_list.index(n2_type_str) + 1)) + + def remove_edge_model(self, edge_type): + if self.hyperparams['edge_encoding']: + if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0: + del self.node_modules[edge_type + '/edge_encoder'] + + def obtain_encoded_tensors(self, + mode, + inputs, + inputs_st, + inputs_np, + robot_present_and_future, + maps): + x, x_r_t, y_r = None, None, None + batch_size = 1 + + our_inputs = inputs[self.node] + our_inputs_st = inputs_st[self.node] + + initial_dynamics = dict() + initial_dynamics['pos'] = our_inputs[:, 0:2] # TODO: Generalize + initial_dynamics['vel'] = our_inputs[:, 2:4] # TODO: Generalize + self.dynamic.set_initial_condition(initial_dynamics) + + ######################################### + # Provide basic information to encoders # + ######################################### + if self.hyperparams['incl_robot_node'] and self.robot is not None: + node_state = torch.zeros((robot_present_and_future.shape[-1], ), dtype=torch.float, device=self.device) + node_state[:our_inputs.shape[1]] = our_inputs[0] + robot_present_and_future_st = get_relative_robot_traj(self.env, self.state, + node_state, robot_present_and_future, + self.node.type, self.robot.type) + x_r_t = robot_present_and_future_st[..., 0, :] + y_r = robot_present_and_future_st[..., 1:, :] + + ################## + # Encode History # + ################## + node_history_encoded = self.encode_node_history(our_inputs_st) + + ############################## + # Encode Node Edges per Type # + ############################## + total_edge_influence = None + if self.hyperparams['edge_encoding']: + node_edges_encoded = list() + for edge_type in self.edge_types: + connected_nodes_batched = list() + edge_masks_batched = list() + + # We get all nodes which are connected to the current node for the current timestep + connected_nodes_batched.append(self.scene_graph.get_neighbors(self.node, + self._get_other_node_type_from_edge( + edge_type))) + + if self.hyperparams['dynamic_edges'] == 'yes': + # We get the edge masks for the current node at the current timestep + edge_masks_for_node = self.scene_graph.get_edge_scaling(self.node) + edge_masks_batched.append(torch.tensor(edge_masks_for_node, dtype=torch.float, device=self.device)) + + # Encode edges for given edge type + encoded_edges_type = self.encode_edge(inputs, + inputs_st, + inputs_np, + edge_type, + connected_nodes_batched, + edge_masks_batched) + node_edges_encoded.append(encoded_edges_type) # List of [bs/nbs, enc_rnn_dim] + + ##################### + # Encode Node Edges # + ##################### + total_edge_influence = self.encode_total_edge_influence(mode, + node_edges_encoded, + node_history_encoded, + batch_size) + + self.TD = {'node_history_encoded': node_history_encoded, + 'total_edge_influence': total_edge_influence} + + ################ + # Map Encoding # + ################ + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + if self.node not in maps: + # This means the node was removed (it is only being kept around because of the edge removal filter). + me_params = self.hyperparams['map_encoder'][self.node_type] + self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])) + else: + encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1., + (mode == ModeKeys.TRAIN)) + do = self.hyperparams['map_encoder'][self.node_type]['dropout'] + encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN)) + self.TD['encoded_map'] = encoded_map + + ###################################### + # Concatenate Encoder Outputs into x # + ###################################### + return self.create_encoder_rep(mode, self.TD, x_r_t, y_r) + + def create_encoder_rep(self, mode, + TD, + robot_present_st, + robot_future_st): + # Unpacking TD + node_history_encoded = TD['node_history_encoded'] + if self.hyperparams['edge_encoding']: + total_edge_influence = TD['total_edge_influence'] + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + encoded_map = TD['encoded_map'] + + if (self.hyperparams['incl_robot_node'] + and self.robot is not None + and robot_future_st is not None + and robot_present_st is not None): + robot_future_encoder = self.encode_robot_future(mode, robot_present_st, robot_future_st) + + # Tiling for multiple samples + # This tiling is done because: + # a) we must consider the prediction case where there are many candidate robot future actions, + # b) the edge and history encoders are all the same regardless of which candidate future robot action + # we're evaluating. + node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1) + if self.hyperparams['edge_encoding']: + total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1) + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + encoded_map = TD['encoded_map'].repeat(robot_future_st.size()[0], 1) + + elif self.hyperparams['incl_robot_node'] and self.robot is not None: + # Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends). + robot_future_encoder = torch.zeros([1, 4 * self.hyperparams['enc_rnn_dim_future']], device=self.device) + + x_concat_list = list() + + # Every node has an edge-influence encoder (which could just be zero). + if self.hyperparams['edge_encoding']: + x_concat_list.append(total_edge_influence) # [bs/nbs, 4*enc_rnn_dim] + + # Every node has a history encoder. + x_concat_list.append(node_history_encoded) # [bs/nbs, enc_rnn_dim_history] + + if self.hyperparams['incl_robot_node'] and self.robot is not None: + x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history] + + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + x_concat_list.append(encoded_map) # [bs/nbs, CNN output size] + + return torch.cat(x_concat_list, dim=1) + + def encode_node_history(self, inputs_st): + new_state = torch.unsqueeze(inputs_st, dim=1) # [bs, 1, state_dim] + if self.node.type + '/node_history_encoder' not in self.curr_hidden_states: + outputs, self.curr_hidden_states[self.node.type + '/node_history_encoder'] = self.node_modules[ + self.node.type + '/node_history_encoder'](new_state) + else: + outputs, self.curr_hidden_states[self.node.type + '/node_history_encoder'] = self.node_modules[ + self.node.type + '/node_history_encoder'](new_state, self.curr_hidden_states[ + self.node.type + '/node_history_encoder']) + + return outputs[:, 0, :] + + def encode_edge(self, inputs, inputs_st, inputs_np, edge_type, connected_nodes, edge_masks): + edge_type_tuple = self._get_edge_type_from_str(edge_type) + edge_states_list = list() # list of [#of neighbors, max_ht, state_dim] + neighbor_states = list() + + orig_rel_state = inputs[self.node].cpu().numpy() + for node in connected_nodes[0]: + neighbor_state_np = inputs_np[node] + + # Make State relative to node + _, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type) + std[0:2] = self.env.attention_radius[edge_type_tuple] + + # TODO: This all makes the unsafe assumption that the first n dims + # refer to the same quantities even for different agent types! + equal_dims = np.min((neighbor_state_np.shape[-1], orig_rel_state.shape[-1])) + rel_state = np.zeros_like(neighbor_state_np) + rel_state[..., :equal_dims] = orig_rel_state[..., :equal_dims] + neighbor_state_np_st = self.env.standardize(neighbor_state_np, + self.state[node.type], + node_type=node.type, + mean=rel_state, + std=std) + + neighbor_state = torch.tensor(neighbor_state_np_st).float().to(self.device) + neighbor_states.append(neighbor_state) + + if len(neighbor_states) == 0: # There are no neighbors for edge type # TODO necessary? + neighbor_state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[edge_type[1]].values()])) + edge_states_list.append(torch.zeros((1, 1, neighbor_state_length), device=self.device)) + else: + edge_states_list.append(torch.stack(neighbor_states, dim=0)) + + if self.hyperparams['edge_state_combine_method'] == 'sum': + # Used in Structural-RNN to combine edges as well. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_mask in edge_masks: + op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_mask, dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + elif self.hyperparams['edge_state_combine_method'] == 'max': + # Used in NLP, e.g. max over word embeddings in a sentence. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_mask in edge_masks: + op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_mask, dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + elif self.hyperparams['edge_state_combine_method'] == 'mean': + # Used in NLP, e.g. mean over word embeddings in a sentence. + op_applied_edge_states_list = list() + for neighbors_state in edge_states_list: + op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0)) + combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0) + if self.hyperparams['dynamic_edges'] == 'yes': + # Should now be (bs, time, 1) + op_applied_edge_mask_list = list() + for edge_mask in edge_masks: + op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_mask, dim=0, keepdim=True), max=1.)) + combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0) + + joint_history = torch.cat([combined_neighbors, torch.unsqueeze(inputs_st[self.node], dim=0)], dim=-1) + + if edge_type + '/edge_encoder' not in self.curr_hidden_states: + outputs, self.curr_hidden_states[edge_type + '/edge_encoder'] = self.node_modules[ + edge_type + '/edge_encoder'](joint_history) + else: + outputs, self.curr_hidden_states[edge_type + '/edge_encoder'] = self.node_modules[ + edge_type + '/edge_encoder'](joint_history, self.curr_hidden_states[edge_type + '/edge_encoder']) + + if self.hyperparams['dynamic_edges'] == 'yes': + return outputs[:, 0, :] * combined_edge_masks + else: + return outputs[:, 0, :] # [bs, enc_rnn_dim] + + def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None, maps=None): + # Always predicting with the online model. + mode = ModeKeys.PREDICT + + self.x = self.obtain_encoded_tensors(mode, + inputs, + inputs_st, + inputs_np, + robot_present_and_future, + maps) + self.n_s_t0 = inputs_st[self.node] + + self.latent.p_dist = self.p_z_x(mode, self.x) + + # robot_future_st is optional here since you can use the same one from encoder_forward, + # but if it's given then we'll re-run that part of the model (if the node is adjacent to the robot). + def decoder_forward(self, prediction_horizon, + num_samples, + robot_present_and_future=None, + z_mode=False, + gmm_mode=False, + full_dist=False, + all_z_sep=False): + # Always predicting with the online model. + mode = ModeKeys.PREDICT + + x_nr_t, y_r = None, None + if (self.hyperparams['incl_robot_node'] + and self.robot is not None + and robot_present_and_future is not None): + our_inputs = torch.tensor(self.node.get(np.array([self.node.last_timestep]), + self.state[self.node.type], + padding=0.0), + dtype=torch.float, + device=self.device) + + node_state = torch.zeros((robot_present_and_future.shape[-1], ), dtype=torch.float, device=self.device) + node_state[:our_inputs.shape[1]] = our_inputs[0] + + robot_present_and_future_st = get_relative_robot_traj(self.env, self.state, + node_state, robot_present_and_future, + self.node.type, self.robot.type) + x_nr_t = robot_present_and_future_st[..., 0, :] + y_r = robot_present_and_future_st[..., 1:, :] + self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r) + self.latent.p_dist = self.p_z_x(mode, self.x) + + # Making sure n_s_t0 has the same batch size as x_nr_t + self.n_s_t0 = self.n_s_t0[[0]].repeat(x_nr_t.size()[0], 1) + + z, num_samples, num_components = self.latent.sample_p(num_samples, + mode, + most_likely_z=z_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) + + y_dist, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, self.n_s_t0, z, + prediction_horizon, + num_samples, + num_components, + gmm_mode) + + return y_dist, our_sampled_future diff --git a/diffstack/modules/predictors/trajectron_utils/model/online/online_trajectron.py b/diffstack/modules/predictors/trajectron_utils/model/online/online_trajectron.py new file mode 100644 index 0000000..cb6e250 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/model/online/online_trajectron.py @@ -0,0 +1,310 @@ +import torch +import numpy as np +from collections import Counter +from model.trajectron import Trajectron +from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE +from model.model_utils import ModeKeys +from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of + + +class OnlineTrajectron(Trajectron): + def __init__(self, model_registrar, + hyperparams, device): + super(OnlineTrajectron, self).__init__(model_registrar=model_registrar, + hyperparams=hyperparams, + log_writer=False, + device=device) + + # We don't really care that this is a nn.ModuleDict, since + # we want to index it by node object anyways. + del self.node_models_dict + self.node_models_dict = dict() + + self.node_data = dict() + self.scene_graph = None + self.RING_CAPACITY = max(len(self.hyperparams['edge_removal_filter']), + len(self.hyperparams['edge_addition_filter']), + self.hyperparams['maximum_history_length']) + 1 + self.rel_states = dict() + self.removed_nodes = Counter() + + def __repr__(self): + return f"OnlineTrajectron(# nodes: {len(self.nodes)}, device: {self.device}, hyperparameters: {str(self.hyperparams)}) " + + def _add_node_model(self, node): + if node in self.nodes: + raise ValueError('%s was already added to this graph!' % str(node)) + + self.nodes.add(node) + self.node_models_dict[node] = OnlineMultimodalGenerativeCVAE(self.env, + node, + self.model_registrar, + self.hyperparams, + self.device) + + def update_removed_nodes(self): + for node in list(self.removed_nodes.keys()): + if self.removed_nodes[node] >= len(self.hyperparams['edge_removal_filter']): + del self.node_data[node] + del self.removed_nodes[node] + + def _remove_node_model(self, node): + if node not in self.nodes: + raise ValueError('%s is not in this graph!' % str(node)) + + self.nodes.remove(node) + del self.node_models_dict[node] + + def set_environment(self, env, init_timestep=0): + self.env = env + self.scene_graph = SceneGraph(edge_radius=self.env.attention_radius) + self.nodes.clear() + self.node_data.clear() + self.node_models_dict.clear() + + # Fast-forwarding ourselves to the initial timestep, without running any of the underlying models. + for timestep in range(init_timestep + 1): + self.incremental_forward(self.env.scenes[0].get_clipped_input_dict(timestep, self.hyperparams['state']), + maps=None, run_models=False) + + def incremental_forward(self, new_inputs_dict, + maps, + prediction_horizon=0, + num_samples=0, + robot_present_and_future=None, + z_mode=False, + gmm_mode=False, + full_dist=False, + all_z_sep=False, + run_models=True): + # The way this function works is by appending the new datapoints to the + # ends of each of the LSTMs in the graph. Then, we recalculate the + # encoder's output vector h_x and feed that into the decoder to sample new outputs. + mode = ModeKeys.PREDICT + + # No grad since we're predicting always, as evidenced by the line above. + with torch.no_grad(): + for node, new_input in new_inputs_dict.items(): + if node not in self.node_data: + self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY, + dtype=(float, sum(len(self.state[node.type][k]) + for k in self.state[node.type]))) + self.node_data[node].append(new_input) + + if node in self.removed_nodes: + del self.removed_nodes[node] + + # Nodes in self.node_data that aren't in new_inputs_dict were just removed. + newly_removed_nodes = (set(self.node_data.keys()) - set(self.removed_nodes.keys())) - set( + new_inputs_dict.keys()) + + # We update self.removed_nodes with the newly removed nodes as well as all existing removed nodes to get + # the time since their last removal increased by one. + self.removed_nodes.update(newly_removed_nodes | set(self.removed_nodes.keys())) + + # For any nodes that are older than the length of the edge_removal_filter, we can safely clear their data. + self.update_removed_nodes() + + # Any remaining removed nodes that aren't yet old enough for data clearing simply have NaNs appended so + # that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role + # (the NaNs get converted to zeros later on). + for node in self.removed_nodes: + self.node_data[node].append(np.full((1, self.node_data[node].shape[1]), np.nan)) + + for node in self.node_data: + node.overwrite_data(self.node_data[node], None, + forward_in_time_on_next_overwrite=(self.node_data[node].shape[0] + == self.RING_CAPACITY)) + + temp_scene_dict = {k: v[:, 0:2] for k, v in self.node_data.items()} + if not temp_scene_dict: + new_scene_graph = SceneGraph(edge_radius=self.env.attention_radius) + else: + new_scene_graph = TemporalSceneGraph.create_from_temp_scene_dict( + temp_scene_dict, + self.env.attention_radius, + duration=self.RING_CAPACITY, + edge_addition_filter=self.hyperparams['edge_addition_filter'], + edge_removal_filter=self.hyperparams['edge_removal_filter'], + online=True).to_scene_graph(t=self.RING_CAPACITY - 1) + + if self.hyperparams['dynamic_edges'] == 'yes': + new_nodes, removed_nodes, new_neighbors, removed_neighbors = new_scene_graph - self.scene_graph + + # Aside from updating the scene graph, this for loop updates the graph model + # structure of all affected nodes. + not_removed_nodes = [node for node in self.nodes if node not in removed_nodes] + self.scene_graph = new_scene_graph + for node in not_removed_nodes: + self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors) + + # These next 2 for loops add or remove entire node models. + for node in new_nodes: + if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys(): + # Only deal with Models for NodeTypes we want to predict + continue + + self._add_node_model(node) + self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors) + + for node in removed_nodes: + if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys(): + continue + + self._remove_node_model(node) + + # This actually updates the node models with the newly observed data. + if run_models: + inputs = dict() + inputs_st = dict() + inputs_np = dict() + + iter_list = list(self.node_models_dict.keys()) + [node for node in new_inputs_dict + if node.type not in self.pred_state.keys()] + if self.env.scenes[0].robot is not None: + iter_list.append(self.env.scenes[0].robot) + + for node in iter_list: + input_np = node.get(np.array([node.last_timestep, node.last_timestep]), self.state[node.type]) + + _, std = self.env.get_standardize_params(self.state[node.type.name], node.type) + std[0:2] = self.env.attention_radius[(node.type, node.type)] + rel_state = np.zeros_like(input_np) + rel_state[:, 0:2] = input_np[:, 0:2] + input_st = self.env.standardize(input_np, + self.state[node.type.name], + node.type, + mean=rel_state) + self.rel_states[node] = rel_state + + # Converting NaNs to zeros. + input_np[np.isnan(input_np)] = 0 + input_st[np.isnan(input_st)] = 0 + + # Convert to torch tensors + inputs[node] = torch.tensor(input_np, dtype=torch.float, device=self.device) + inputs_st[node] = torch.tensor(input_st, dtype=torch.float, device=self.device) + inputs_np[node] = input_np + + # We want tensors of shape (1, ph + 1, state_dim) where the first 1 is the batch size. + if (self.hyperparams['incl_robot_node'] + and self.env.scenes[0].robot is not None + and robot_present_and_future is not None): + if len(robot_present_and_future.shape) == 2: + robot_present_and_future = robot_present_and_future[np.newaxis, :] + + assert robot_present_and_future.shape[1] == prediction_horizon + 1 + robot_present_and_future = torch.tensor(robot_present_and_future, + dtype=torch.float, device=self.device) + + for node in self.node_models_dict: + self.node_models_dict[node].encoder_forward(inputs, + inputs_st, + inputs_np, + robot_present_and_future, + maps) + + # If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all, + # just update the encoder LSTMs. + if prediction_horizon == 0 or num_samples == 0: + return + + return self.sample_model(prediction_horizon, + num_samples, + robot_present_and_future=robot_present_and_future, + z_mode=z_mode, + gmm_mode=gmm_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) + + def _run_decoder(self, node, + num_predicted_timesteps, + num_samples, + robot_present_and_future=None, + z_mode=False, + gmm_mode=False, + full_dist=False, + all_z_sep=False): + model = self.node_models_dict[node] + prediction_dist, predictions_uns = model.decoder_forward(num_predicted_timesteps, + num_samples, + robot_present_and_future=robot_present_and_future, + z_mode=z_mode, + gmm_mode=gmm_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) + + # predictions_np = predictions_uns.cpu().detach().numpy() + + # Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2) + return prediction_dist#, np.transpose(predictions_np, (1, 0, 2, 3)) + + def sample_model(self, num_predicted_timesteps, + num_samples, + robot_present_and_future=None, + z_mode=False, + gmm_mode=False, + full_dist=False, + all_z_sep=False): + # Just start from the encoder output (minus the + # robot future) and get num_samples of + # num_predicted_timesteps-length trajectories. + if num_predicted_timesteps == 0 or num_samples == 0: + return + + mode = ModeKeys.PREDICT + + # We want tensors of shape (1, ph + 1, state_dim) where the first 1 is the batch size. + if self.hyperparams['incl_robot_node'] and self.env.scenes[ + 0].robot is not None and robot_present_and_future is not None: + if len(robot_present_and_future.shape) == 2: + robot_present_and_future = robot_present_and_future[np.newaxis, :] + + assert robot_present_and_future.shape[1] == num_predicted_timesteps + 1 + + # No grad since we're predicting always, as evidenced by the line above. + with torch.no_grad(): + predictions_dict = dict() + prediction_dists = dict() + for node in set(self.nodes) - set(self.removed_nodes.keys()): + if node.is_robot: + continue + + prediction_dists[node] = self._run_decoder(node, num_predicted_timesteps, + num_samples, + robot_present_and_future, + z_mode, + gmm_mode, + full_dist, + all_z_sep) + + return prediction_dists#, predictions_dict + + def forward(self, init_env, + init_timestep, + input_dicts, # After the initial environment + num_predicted_timesteps, + num_samples, + robot_present_and_future=None, + z_mode=False, + gmm_mode=False, + full_dist=False, + all_z_sep=False): + # This is the standard forward prediction function, + # if you have some historical data and just want to + # predict forward some number of timesteps. + + # Setting us back to the initial scene graph we had. + self.set_environment(init_env, init_timestep) + + # Looping through and applying updates to the model. + for i in range(len(input_dicts)): + self.incremental_forward(input_dicts[i]) + + return self.sample_model(num_predicted_timesteps, + num_samples, + robot_present_and_future=robot_present_and_future, + z_mode=z_mode, + gmm_mode=gmm_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) diff --git a/diffstack/modules/predictors/trajectron_utils/node.py b/diffstack/modules/predictors/trajectron_utils/node.py new file mode 100644 index 0000000..22f202e --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/node.py @@ -0,0 +1,265 @@ +import random +import numpy as np +import pandas as pd +from diffstack.modules.predictors.trajectron_utils.environment import DoubleHeaderNumpyArray +from ncls import NCLS +from typing import Tuple + + +class Node(object): + def __init__(self, node_type, node_id, data, length=None, width=None, height=None, first_timestep=0, + is_robot=False, description="", frequency_multiplier=1, non_aug_node=None, extra_data=None): + self.type = node_type + self.id = node_id + self.length = length + self.width = width + self.height = height + self.first_timestep = first_timestep + self.non_aug_node = non_aug_node + + if data is not None: + if isinstance(data, pd.DataFrame): + self.data = DoubleHeaderNumpyArray(data.values, list(data.columns)) + elif isinstance(data, DoubleHeaderNumpyArray): + self.data = data + else: + self.data = None + + self.extra_data = extra_data + + self.is_robot = is_robot + self._last_timestep = None + self.description = description + self.frequency_multiplier = frequency_multiplier + + self.forward_in_time_on_next_override = False + + def __eq__(self, other): + return ((isinstance(other, self.__class__) + or isinstance(self, other.__class__)) + and self.id == other.id + and self.type == other.type) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.type, self.id)) + + def __repr__(self): + return '/'.join([self.type.name, self.id]) + + def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False): + """ + This function hard overwrites the data matrix. When using it you have to make sure that the columns + in the new data matrix correspond to the old structure. As well as setting first_timestep. + + :param data: New data matrix + :param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased. + :return: None + """ + if header is None: + self.data.data = data + else: + self.data = DoubleHeaderNumpyArray(data, header) + + self._last_timestep = None + if self.forward_in_time_on_next_override: + self.first_timestep += 1 + self.forward_in_time_on_next_override = forward_in_time_on_next_overwrite + + def scene_ts_to_node_ts(self, scene_ts) -> Tuple[np.ndarray, int, int]: + """ + Transforms timestamp from scene into timeframe of node data. + + :param scene_ts: Scene timesteps + :return: ts: Transformed timesteps, paddingl: Number of timesteps in scene range which are not available in + node data before data is available. paddingu: Number of timesteps in scene range which are not + available in node data after data is available. + """ + paddingl = (self.first_timestep - scene_ts[0]).clip(0) + paddingu = (scene_ts[1] - self.last_timestep).clip(0) + ts = np.array(scene_ts).clip(min=self.first_timestep, max=self.last_timestep) - self.first_timestep + return ts, paddingl, paddingu + + def history_points_at(self, ts) -> int: + """ + Number of history points in trajectory. Timestep is exclusive. + + :param ts: Scene timestep where the number of history points are queried. + :return: Number of history timesteps. + """ + return ts - self.first_timestep + + def get(self, tr_scene, state, padding=np.nan) -> np.ndarray: + """ + Returns a time range of multiple properties of the node. + + :param tr_scene: The timestep range (inklusive). + :param state: The state description for which the properties are returned. + :param padding: The value which should be used for padding if not enough information is available. + :return: Array of node property values. + """ + if tr_scene.size == 1: + tr_scene = np.array([tr_scene[0], tr_scene[0]]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + data_array = self.data[tr[0]:tr[1] + 1, state] + padded_data_array = np.full((length, data_array.shape[1]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + return padded_data_array + + def get_lane_points(self, tr_scene, padding=np.nan, num_lane_points=16) -> np.ndarray: + """ + :param tr_scene: The timestep range (inklusive). + """ + if tr_scene.size == 1: + tr_scene = np.array([tr_scene[0], tr_scene[0]]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + data_array = self.extra_data['lane_points'][tr[0]:tr[1] + 1] + padded_data_array = np.full((length, data_array.shape[1], data_array.shape[2]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + # extend to fixed num lane points + if num_lane_points is not None: + if padded_data_array.shape[1] == 0: + padded_data_array = np.full((length, num_lane_points, 3), fill_value=padding) + elif padded_data_array.shape[1] < num_lane_points: + pad = np.repeat(padded_data_array[:, -1:], num_lane_points-padded_data_array.shape[1], axis=1) + padded_data_array = np.concatenate((padded_data_array, pad), axis=1) + else: + padded_data_array = padded_data_array[:, :num_lane_points] + return padded_data_array + + @property + def timesteps(self) -> int: + """ + Number of available timesteps for node. + + :return: Number of available timesteps. + """ + return self.data.shape[0] + + @property + def last_timestep(self) -> int: + """ + Nodes last timestep in the Scene. + + :return: Nodes last timestep. + """ + if self._last_timestep is None: + self._last_timestep = self.first_timestep + self.timesteps - 1 + return self._last_timestep + + +class MultiNode(Node): + def __init__(self, node_type, node_id, nodes_list, is_robot=False): + super(MultiNode, self).__init__(node_type, node_id, data=None, is_robot=is_robot) + self.nodes_list = nodes_list + for node in self.nodes_list: + node.is_robot = is_robot + + self.first_timestep = min(node.first_timestep for node in self.nodes_list) + self._last_timestep = max(node.last_timestep for node in self.nodes_list) + + starts = np.array([node.first_timestep for node in self.nodes_list], dtype=np.int64) + ends = np.array([node.last_timestep for node in self.nodes_list], dtype=np.int64) + ids = np.arange(len(self.nodes_list), dtype=np.int64) + self.interval_tree = NCLS(starts, ends, ids) + + @staticmethod + def find_non_overlapping_nodes(nodes_list, min_timesteps=1) -> list: + """ + Greedily finds a set of non-overlapping nodes in the provided scene. + + :return: A list of non-overlapping nodes. + """ + non_overlapping_nodes = list() + nodes = sorted(nodes_list, key=lambda n: n.last_timestep) + current_time = 0 + for node in nodes: + if node.first_timestep >= current_time and node.timesteps >= min_timesteps: + # Include the node + non_overlapping_nodes.append(node) + current_time = node.last_timestep + + return non_overlapping_nodes + + def get_node_at_timesteps(self, scene_ts) -> Node: + possible_node_ranges = list(self.interval_tree.find_overlap(scene_ts[0], scene_ts[1] + 1)) + if not possible_node_ranges: + return Node(node_type=self.type, + node_id='EMPTY', + data=self.nodes_list[0].data * np.nan, + is_robot=self.is_robot) + + node_idx = random.choice(possible_node_ranges)[2] + return self.nodes_list[node_idx] + + def scene_ts_to_node_ts(self, scene_ts) -> Tuple[Node, np.ndarray, int, int]: + """ + Transforms timestamp from scene into timeframe of node data. + + :param scene_ts: Scene timesteps + :return: ts: Transformed timesteps, paddingl: Number of timesteps in scene range which are not available in + node data before data is available. paddingu: Number of timesteps in scene range which are not + available in node data after data is available. + """ + possible_node_ranges = list(self.interval_tree.find_overlap(scene_ts[0], scene_ts[1] + 1)) + if not possible_node_ranges: + return None, None, None, None + + node_idx = random.choice(possible_node_ranges)[2] + node = self.nodes_list[node_idx] + + paddingl = (node.first_timestep - scene_ts[0]).clip(0) + paddingu = (scene_ts[1] - node.last_timestep).clip(0) + ts = np.array(scene_ts).clip(min=node.first_timestep, max=node.last_timestep) - node.first_timestep + return node, ts, paddingl, paddingu + + def get(self, tr_scene, state, padding=np.nan) -> np.ndarray: + if tr_scene.size == 1: + tr_scene = np.array([tr_scene, tr_scene]) + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + + node, tr, paddingl, paddingu = self.scene_ts_to_node_ts(tr_scene) + if node is None: + state_length = sum([len(entity_dims) for entity_dims in state.values()]) + return np.full((length, state_length), fill_value=padding) + + data_array = node.data[tr[0]:tr[1] + 1, state] + padded_data_array = np.full((length, data_array.shape[1]), fill_value=padding) + padded_data_array[paddingl:length - paddingu] = data_array + return padded_data_array + + def get_all(self, tr_scene, state, padding=np.nan) -> np.ndarray: + # Assumption here is that the user is asking for all of the data in this MultiNode and to return it within a + # full scene-sized output array. + assert tr_scene.size == 2 and tr_scene[0] == 0 and self.last_timestep <= tr_scene[1] + length = tr_scene[1] - tr_scene[0] + 1 # tr is inclusive + state_length = sum([len(entity_dims) for entity_dims in state.values()]) + padded_data_array = np.full((length, state_length), fill_value=padding) + for node in self.nodes_list: + padded_data_array[node.first_timestep:node.last_timestep + 1] = node.data[:, state] + + return padded_data_array + + def history_points_at(self, ts) -> int: + """ + Number of history points in trajectory. Timestep is exclusive. + + :param ts: Scene timestep where the number of history points are queried. + :return: Number of history timesteps. + """ + node_idx = next(self.interval_tree.find_overlap(ts, ts + 1))[2] + node = self.nodes_list[node_idx] + return ts - node.first_timestep + + @property + def timesteps(self) -> int: + """ + Number of available timesteps for node. + + :return: Number of available timesteps. + """ + return self._last_timestep - self.first_timestep + 1 diff --git a/diffstack/modules/predictors/trajectron_utils/node_type.py b/diffstack/modules/predictors/trajectron_utils/node_type.py new file mode 100644 index 0000000..20b36da --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/node_type.py @@ -0,0 +1,36 @@ +class NodeType(object): + def __init__(self, name, value): + self.name = name + self.value = value + + def __repr__(self): + return self.name + + def __eq__(self, other): + if type(other) == str and self.name == other: + return True + else: + # Only check if class names match, so relative and absolute imports will be treated equal. + return (isinstance(other, self.__class__) or (other.__class__.__name__ == self.__class__.__name__)) and self.name == other.name + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.name) + + def __add__(self, other): + return self.name + other + + +class NodeTypeEnum(list): + def __init__(self, node_type_list): + self.node_type_list = node_type_list + node_types = [NodeType(name, node_type_list.index(name) + 1) for name in node_type_list] + super().__init__(node_types) + + def __getattr__(self, name): + if not name.startswith('_') and name in object.__getattribute__(self, "node_type_list"): + return self[object.__getattribute__(self, "node_type_list").index(name)] + else: + return object.__getattribute__(self, name) diff --git a/diffstack/modules/predictors/trajectron_utils/trajectron/__init__.py b/diffstack/modules/predictors/trajectron_utils/trajectron/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/modules/predictors/trajectron_utils/trajectron/trajectron.py b/diffstack/modules/predictors/trajectron_utils/trajectron/trajectron.py new file mode 100644 index 0000000..b2b1990 --- /dev/null +++ b/diffstack/modules/predictors/trajectron_utils/trajectron/trajectron.py @@ -0,0 +1,6 @@ +import torch +from torch import nn + + +class Trajectron(nn.Module): + pass \ No newline at end of file diff --git a/diffstack/scripts/generate_config_templates.py b/diffstack/scripts/generate_config_templates.py new file mode 100644 index 0000000..bf927b6 --- /dev/null +++ b/diffstack/scripts/generate_config_templates.py @@ -0,0 +1,19 @@ +""" +Helpful script to generate example config files for each algorithm. These should be re-generated +when new config options are added, or when default settings in the config classes are modified. +""" +import os + +import diffstack +from diffstack.configs.registry import EXP_CONFIG_REGISTRY + + +def main(): + # store template config jsons in this directory + target_dir = os.path.join(diffstack.__path__[0], "../config/templates/") + for name, cfg in EXP_CONFIG_REGISTRY.items(): + cfg.dump(filename=os.path.join(target_dir, name + ".json")) + + +if __name__ == "__main__": + main() diff --git a/diffstack/scripts/train_pl.py b/diffstack/scripts/train_pl.py new file mode 100644 index 0000000..9a77eaa --- /dev/null +++ b/diffstack/scripts/train_pl.py @@ -0,0 +1,582 @@ +import argparse +import sys +import os +import socket +import torch + +import wandb +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.profilers import PyTorchProfiler + +from diffstack.utils.log_utils import PrintLogger +from diffstack.utils.experiment_utils import get_checkpoint +import diffstack.utils.train_utils as TrainUtils +from diffstack.data.trajdata_datamodules import UnifiedDataModule +from diffstack.configs.registry import get_registered_experiment_config +from diffstack.utils.config_utils import ( + get_experiment_config_from_file, + recursive_update_flat, + translate_trajdata_cfg, +) +from diffstack.stacks.stack_factory import stack_factory +from pathlib import Path +import json +import multiprocessing as mp +from diffstack.configs.config import Dict + +test = False + + +def main(cfg, auto_remove_exp_dir=False, debug=False, evaluate=False, **extra_kwargs): + pl.seed_everything(cfg.seed) + # Dataset + trajdata_config = translate_trajdata_cfg(cfg) + + model = stack_factory( + cfg=cfg, + ) + # if test: + # model.set_eval() + # import pickle + # with open("homo_test_case.pkl","rb") as f: + # batch = pickle.load(f) + # from diffstack.modules.module import Module, DataFormat, RunMode + # with torch.no_grad(): + # res = model.components["predictor"]._run_forward(dict(parsed_batch=batch),RunMode.VALIDATE) + # model.components["predictor"].log_pred_image(batch,res,501,"curated_result") + + datamodule = UnifiedDataModule(data_config=trajdata_config, train_config=cfg.train) + + datamodule.setup() + + if not evaluate: + print("\n============= New Training Run with Config =============") + print(cfg) + print("") + root_dir, log_dir, ckpt_dir, video_dir, version_key = TrainUtils.get_exp_dir( + exp_name=cfg.name, + output_dir=cfg.root_dir, + save_checkpoints=cfg.train.save.enabled, + auto_remove_exp_dir=auto_remove_exp_dir, + ) + + # Save experiment config to the training dir + cfg.dump(os.path.join(root_dir, version_key, "config.json")) + + # if cfg.train.logging.terminal_output_to_txt and not debug: + # # log stdout and stderr to a text file + # logger = PrintLogger(os.path.join(log_dir, "log.txt")) + # sys.stdout = logger + # sys.stderr = logger + + train_callbacks = [] + + # Training Parallelism + assert cfg.train.parallel_strategy in [ + "ddp", + "ddp_spawn", + # "ddp", # TODO for ddp we need to look at NODE_RANK and disable logging in config + None, + ] # TODO: look into other strategies + if not cfg.devices.num_gpus > 1: + # Override strategy when training on a single GPU + with cfg.train.unlocked(): + cfg.train.parallel_strategy = None + if cfg.train.parallel_strategy in ["ddp_spawn", "ddp"]: + with cfg.train.training.unlocked(): + cfg.train.training.batch_size = int( + cfg.train.training.batch_size / cfg.devices.num_gpus + ) + with cfg.train.validation.unlocked(): + cfg.train.validation.batch_size = int( + cfg.train.validation.batch_size / cfg.devices.num_gpus + ) + + # # Environment for close-loop evaluation + # if cfg.train.rollout.enabled: + # # Run rollout at regular intervals + # rollout_callback = RolloutCallback( + # exp_config=cfg, + # every_n_steps=cfg.train.rollout.every_n_steps, + # warm_start_n_steps=cfg.train.rollout.warm_start_n_steps, + # verbose=True, + # save_video=cfg.train.rollout.save_video, + # video_dir=video_dir + # ) + # train_callbacks.append(rollout_callback) + + # Model + + # Checkpointing + if cfg.train.validation.enabled and cfg.train.save.save_best_validation: + assert ( + cfg.train.save.every_n_steps > cfg.train.validation.every_n_steps + ), "checkpointing frequency needs to be greater than validation frequency" + for metric_name, metric_key in model.checkpoint_monitor_keys.items(): + print( + "Monitoring metrics {} under alias {}".format( + metric_key, metric_name + ) + ) + ckpt_valid_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}_ep{epoch}_%s{%s:.2f}" + % (metric_name, metric_key), + # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories + auto_insert_metric_name=False, + save_top_k=cfg.train.save.best_k, # save the best k models + monitor=metric_key, + mode="min", + every_n_train_steps=cfg.train.save.every_n_steps, + verbose=True, + ) + train_callbacks.append(ckpt_valid_callback) + + if cfg.train.rollout.enabled and cfg.train.save.save_best_rollout: + assert ( + cfg.train.save.every_n_steps > cfg.train.rollout.every_n_steps + ), "checkpointing frequency needs to be greater than rollout frequency" + ckpt_rollout_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}_ep{epoch}_simADE{rollout/metrics_ego_ADE:.2f}", + # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories + auto_insert_metric_name=False, + save_top_k=cfg.train.save.best_k, # save the best k models + monitor="rollout/metrics_ego_ADE", + mode="min", + every_n_train_steps=cfg.train.save.every_n_steps, + verbose=True, + ) + train_callbacks.append(ckpt_rollout_callback) + + # a ckpt monitor to save at fixed interval + ckpt_fixed_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}", + auto_insert_metric_name=False, + save_top_k=-1, + monitor=None, + every_n_train_steps=10000, + verbose=True, + ) + train_callbacks.append(ckpt_fixed_callback) + + # def wandb_login(i,return_dict): + # apikey = os.environ["WANDB_APIKEY"] + # wandb.login(key=apikey,host="https://api.wandb.ai") + # logger = WandbLogger( + # name=cfg.name, project=cfg.train.logging.wandb_project_name, + # ) + # return_dict[i] = logger + # manager = mp.Manager() + + logger = None + if debug: + print("Debugging mode, suppress logging.") + elif cfg.train.logging.log_tb: + logger = TensorBoardLogger( + save_dir=root_dir, version=version_key, name=None, sub_dir="logs/" + ) + print("Tensorboard event will be saved at {}".format(logger.log_dir)) + elif cfg.train.logging.log_wandb: + assert ( + "WANDB_APIKEY" in os.environ + ), "Set api key by `export WANDB_APIKEY=`" + try: + apikey = os.environ["WANDB_APIKEY"] + wandb.login(key=apikey, host="https://api.wandb.ai") + logger = WandbLogger( + name=cfg.name, + project=cfg.train.logging.wandb_project_name, + ) + except: + logger = None + # return_dict = manager.dict() + # p1 = mp.Process(target=wandb_login,args=(0,return_dict)) + # p1.start() + # p1.join(timeout=30) + # p1.terminate() + # if 0 in return_dict: + # logger = return_dict[0] + + else: + print("WARNING: not logging training stats") + + # Train + kwargs = dict( + default_root_dir=root_dir, + # checkpointing + enable_checkpointing=cfg.train.save.enabled, + # logging + logger=logger, + # flush_logs_every_n_steps=cfg.train.logging.flush_every_n_steps, + log_every_n_steps=cfg.train.logging.log_every_n_steps, + # training + max_steps=cfg.train.training.num_steps, + # validation + val_check_interval=cfg.train.validation.every_n_steps, + limit_val_batches=cfg.train.validation.num_steps_per_epoch, + # all callbacks + callbacks=train_callbacks, + # device & distributed training setup + # accelerator='cpu', + # accelerator=('gpu' if cfg.devices.num_gpus > 0 else 'cpu'), + devices=(cfg.devices.num_gpus if cfg.devices.num_gpus > 0 else None), + # strategy=cfg.train.parallel_strategy if cfg.train.parallel_strategy is not None else DDPStrategy(find_unused_parameters=True), + strategy="ddp_find_unused_parameters_true", + accelerator="gpu", + gradient_clip_val=cfg.train.gradient_clip_val, + # detect_anomaly=True, + # setting for overfit debugging + # limit_val_batches=0, + # overfit_batches=2 + ) + # if debug: + # profiler = PyTorchProfiler( + # output_filename=None, + # enabled=True, + # use_cuda=cfg.devices.num_gpus > 0, + # record_shapes=False, + # profile_memory=True, + # group_by_input_shapes=False, + # with_stack=False, + # use_kineto=False, + # use_cpu=cfg.devices.num_gpus == 0, + # emit_nvtx=False, + # export_to_chrome=False, + # path_to_export_trace=None, + # row_limit=200, + # sort_by_key=None, + # profiled_functions=None, + # local_rank=None, + # ) + # kwargs["profiler"] = profiler + # kwargs["max_steps"] = extra_kwargs["profile_steps"] + + if cfg.train.get("amp", False): + kwargs["precision"] = 16 + # kwargs["amp_backend"] = "apex" + # kwargs["amp_level"] = "O2" + # if cfg.train.get("auto_batch_size",False): + # kwargs["auto_scale_batch_size"] = "binsearch" + + if cfg.train.get("auto_batch_size", False): + from diffstack.utils.train_utils import trajdata_auto_set_batch_size + + kwargs_tune = kwargs.copy() + kwargs_tune["max_steps"] = 3 + trial_trainer = pl.Trainer(**kwargs_tune) + trial_model = stack_factory( + cfg=cfg, + ) + bs_max = cfg.train.get("max_batch_size", None) + batch_size = trajdata_auto_set_batch_size( + trial_trainer, + trial_model, + datamodule, + bs_min=cfg.train.training.batch_size, + bs_max=bs_max, + ) + # batch_size = trajdata_auto_set_batch_size(trial_trainer,trial_model,datamodule,bs_min = 50,bs_max = 58) + datamodule.train_batch_size = batch_size + datamodule.val_batch_size = batch_size + del trial_trainer, trial_model + if cfg.devices.num_gpus > 0: + torch.cuda.empty_cache() + torch.set_float32_matmul_precision("medium") + + trainer = pl.Trainer(**kwargs) + # Logging + assert not (cfg.train.logging.log_tb and cfg.train.logging.log_wandb) + + if isinstance(logger, WandbLogger): + # record the entire config on wandb + if trainer.global_rank == 0: + logger.experiment.config.update(cfg.to_dict()) + logger.watch(model=model) + # kwargs_tune = kwargs.copy() + # kwargs_tune["max_steps"] = 30 + # trial_trainer = pl.Trainer(**kwargs_tune) + # trial_trainer.fit(model=model, datamodule=datamodule) + trainer.fit(model=model, datamodule=datamodule) + + else: + kwargs = dict( + devices=(1 if cfg.devices.num_gpus > 0 else None), + # strategy=cfg.train.parallel_strategy if cfg.train.parallel_strategy is not None else DDPStrategy(find_unused_parameters=True), + accelerator="auto", + ) + trainer = pl.Trainer(**kwargs) + # Evaluation + + model.set_eval() + + metrics = trainer.test(model, datamodule=datamodule) + if len(metrics) == 1: + flattened_metrics = metrics[0] + else: + flattened_metrics = dict() + for k, v in metrics[0].items(): + flattened_metrics[k] = sum([m[k] for m in metrics]) / len(metrics) + result_path = extra_kwargs.get( + "eval_output_dir", os.path.join(os.getcwd(), "eval_result") + ) + file_name = f"{cfg.registered_name}_{cfg.train.trajdata_test_source_root}_{cfg.train.trajdata_source_test}_eval_result.json" + + with open(os.path.join(result_path, file_name), "w") as f: + json.dump(flattened_metrics, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # External config file that overwrites default config + parser.add_argument( + "--config_file", + type=str, + default=None, + help="(optional) path to a config json that will be used to override the default settings. \ + If omitted, default settings are used. This is the preferred way to run experiments.", + ) + + parser.add_argument( + "--config_name", + type=str, + default=None, + help="(optional) create experiment config from a preregistered name (see configs/registry.py)", + ) + # Experiment Name (for tensorboard, saving models, etc.) + parser.add_argument( + "--name", + type=str, + default=None, + help="(optional) if provided, override the experiment name defined in the config", + ) + + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help="(optional) if provided, override the wandb project name defined in the config", + ) + + parser.add_argument( + "--dataset_path", + type=str, + default=None, + help="(optional) if provided, override the dataset root path", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Root directory of training output (checkpoints, visualization, tensorboard log, etc.)", + ) + + parser.add_argument( + "--remove_exp_dir", + action="store_true", + help="Whether to automatically remove existing experiment directory of the same name (remember to set this to " + "True to avoid unexpected stall when launching cloud experiments).", + ) + + parser.add_argument( + "--on_ngc", + action="store_true", + help="whether running the script on ngc (this will change some behaviors like avoid writing into dataset)", + ) + + parser.add_argument( + "--debug", action="store_true", help="Debug mode, suppress wandb logging, etc." + ) + + parser.add_argument( + "--profile_steps", type=int, default=500, help="number of steps to run profiler" + ) + # evaluation mode + parser.add_argument( + "--evaluate", action="store_true", help="Evaluate mode, suppress training." + ) + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="path to the checkpoint to be evaluated", + ) + parser.add_argument( + "--eval_output_dir", + type=str, + default=None, + help="path to store evaluation result", + ) + parser.add_argument( + "--ckpt_root_dir", type=str, default=None, help="path of ngc checkpoint folder" + ) + parser.add_argument("--ngc_job_id", type=str, default=None, help="ngc job id") + parser.add_argument("--ckpt_key", type=str, default=None, help="ngc checkpoint key") + parser.add_argument( + "--test_data", type=str, default=None, help="trajdata_source_test in config" + ) + parser.add_argument( + "--test_data_root", + type=str, + default=None, + help="trajdata_test_source_root in config", + ) + parser.add_argument( + "--test_batch_size", type=int, default=None, help="batch size of test run" + ) + parser.add_argument( + "--log_image_frequency", type=int, default=None, help="image logging frequency" + ) + + parser.add_argument( + "--log_all_image", + action="store_true", + help="log images for all scenes instead of only the first one for every batch", + ) + + parser.add_argument( + "--remove_parked", + action="store_true", + help="remove parked agents from the scene", + ) + + args = parser.parse_args() + + if args.config_name is not None: + default_config = get_registered_experiment_config(args.config_name) + elif args.config_file is not None: + # Update default config with external json file + default_config = get_experiment_config_from_file(args.config_file, locked=False) + elif args.evaluate: + ckpt_path = None + if args.ckpt_path is not None: + config_path = str(Path(args.ckpt_path).parents[1] / "config.json") + ckpt_path = args.ckpt_path + elif ( + args.ngc_job_id is not None + and args.ckpt_key is not None + and args.ckpt_root_dir is not None + ): + ckpt_path, config_path = get_checkpoint( + ngc_job_id=args.ngc_job_id, + ckpt_key=args.ckpt_key, + ckpt_root_dir=args.ckpt_root_dir, + ) + default_config = get_experiment_config_from_file(config_path) + + if default_config.stack.stack_type == "pred": + default_config.stack.predictor.load_checkpoint = ckpt_path + elif default_config.stack.stack_type == "plan": + default_config.stack.planner.load_checkpoint = ckpt_path + else: + raise NotImplementedError + + if "test" not in default_config.train: + default_config.train.test = Dict( + enabled=True, + batch_size=32, + num_data_workers=6, + every_n_steps=500, + num_steps_per_epoch=50, + ) + if args.remove_parked: + default_config.env.remove_parked = True + if args.test_data is not None: + default_config.train.trajdata_source_test = args.test_data + if args.test_data_root is not None: + default_config.train.trajdata_test_source_root = args.test_data_root + # modify dataset path to reduce loading time + if default_config.train.trajdata_source_test is None: + default_config.train.trajdata_source_test = ( + default_config.train.trajdata_source_valid + ) + if default_config.train.trajdata_test_source_root is None: + default_config.train.trajdata_test_source_root = ( + default_config.train.trajdata_val_source_root + ) + + default_config.train.trajdata_val_source_root = None + default_config.train.trajdata_source_valid = ( + default_config.train.trajdata_source_test + ) + default_config.train.trajdata_source_root = ( + default_config.train.trajdata_test_source_root + ) + default_config.train.trajdata_source_train = ( + default_config.train.trajdata_source_test + ) + if args.test_batch_size is not None: + default_config.train.test["batch_size"] = args.test_batch_size + if ( + "predictor" in default_config.stack + and default_config.stack.predictor.name == "CTT" + ): + default_config.env.max_num_lanes = 64 + default_config.stack.predictor.decoder.decode_num_modes = 8 + default_config.stack.predictor.LR_sample_hack = False + # default_config.env.remove_single_successor = False + default_config.eval.results_dir = args.eval_output_dir + if args.ngc_job_id is not None: + default_config.registered_name = ( + default_config.registered_name + "_" + args.ngc_job_id + ) + default_config.eval.log_image_frequency = args.log_image_frequency + if args.log_all_image: + default_config.eval.log_all_image = True + else: + default_config.eval.log_all_image = False + + else: + raise Exception( + "Need either a config name or a json file to create experiment config" + ) + + if args.name is not None: + default_config.name = args.name + + if args.dataset_path is not None: + default_config.train.dataset_path = args.dataset_path + + if args.output_dir is not None: + default_config.root_dir = os.path.abspath(args.output_dir) + + if args.wandb_project_name is not None: + default_config.train.logging.wandb_project_name = args.wandb_project_name + + if args.on_ngc: + ngc_job_id = socket.gethostname() + default_config.name = default_config.name + "_" + ngc_job_id + + default_config.train.on_ngc = args.on_ngc + args_dict = vars(args) + args_dict = { + k: v + for k, v in args_dict.items() + if k + not in ["name", "dataset_path", "output_dir", "wandb_project_name", "on_ngc"] + } + default_config, leftover = recursive_update_flat(default_config, args_dict) + if len(leftover) > 0: + Warning(f"Arguments {list(leftover.keys())} are not found in the config") + + if args.debug: + # Test policy rollout + default_config.train.rollout.every_n_steps = 10 + default_config.train.rollout.num_episodes = 1 + + # make rollout evaluation config consistent with the rest of the config + + default_config.lock() # Make config read-only + main( + default_config, + auto_remove_exp_dir=args.remove_exp_dir, + debug=args.debug, + profile_steps=args.profile_steps, + evaluate=args.evaluate, + eval_output_dir=args.eval_output_dir, + ) diff --git a/diffstack/stacks/base.py b/diffstack/stacks/base.py new file mode 100644 index 0000000..8c93cb4 --- /dev/null +++ b/diffstack/stacks/base.py @@ -0,0 +1,347 @@ +from diffstack.modules.module import ModuleSequence, DataFormat +from trajdata.data_structures.batch import SceneBatch, AgentBatch +import pytorch_lightning as pl +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.utils import removeprefix +import torch +import torch.nn as nn +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +import wandb + + +class AVStack(pl.LightningModule): + def __init__(self, modules: ModuleSequence, cfg, batch_size=None, **kwargs): + super(AVStack, self).__init__() + self.moduleseq = modules + self.components = nn.ModuleDict() + for k, v in self.moduleseq.components.items(): + self.components[k] = v + self.cfg = cfg + if "monitor_key" in kwargs: + self.monitor_key = kwargs["monitor_key"] + else: + self.monitor_key = {"valLoss": "val/losses_predictor_prediction_loss"} + + self.batch_size = batch_size + self.validation_step_outputs = [] + + @property + def input_format(self) -> DataFormat: + self.moduleseq.input_format + + @property + def output_format(self) -> DataFormat: + self.moduleseq.output_format + + @property + def checkpoint_monitor_keys(self): + return self.monitor_key + + def forward(self, inputs, **kwargs): + return self.moduleseq(inputs, **kwargs) + + def infer_step(self, inputs, **kwargs): + with torch.no_grad(): + return TensorUtils.detach(self.moduleseq.infer_step(inputs, **kwargs)) + + def _compute_losses(self, pout, inputs): + loss_by_component = dict() + for comp_name, component in self.moduleseq.components.items(): + comp_out = { + removeprefix(k, comp_name + "."): v + for k, v in pout.items() + if k.startswith(comp_name + ".") + } + loss_by_component[comp_name] = component.compute_losses(comp_out, inputs) + + return loss_by_component + + def training_step(self, batch, batch_idx): + """ + Training on a single batch of data. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader and filtered by @process_batch_for_training + + batch_idx (int): training step number - required by some Algos that need + to perform staged training and early stopping + + Returns: + total_loss(torch.Tensor): total weighted loss + """ + + pout = self(batch, batch_idx=batch_idx) + losses = self._compute_losses(pout, batch) + + total_loss = 0.0 + for comp_name in self.moduleseq.components: + comp_loss = losses[comp_name] + for lk, l in comp_loss.items(): + loss = l * self.cfg.stack[comp_name].loss_weights[lk] + self.log("train/losses_" + comp_name + "_" + lk, loss, sync_dist=True) + total_loss += loss + if batch_idx % 100 == 0: + metrics = self._compute_metrics(pout, batch) + for mk, m in metrics.items(): + if isinstance(m, np.ndarray) or isinstance(m, torch.Tensor): + m = m.mean() + self.log("train/metrics_" + mk, m, sync_dist=True) + + return total_loss + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + pout = TensorUtils.detach( + self.moduleseq.validate_step(batch, batch_idx=batch_idx) + ) + losses = self._compute_losses(pout, batch) + + if self.logger is not None and batch_idx == 1: + if "predictor" in self.components and hasattr( + self.components["predictor"], "log_pred_image" + ): + self.components["predictor"].log_pred_image( + batch, pout, batch_idx, self.logger + ) + else: + # default + self.log_pred_image(batch, pout, batch_idx, self.logger) + print("image logged") + + metrics = self._compute_metrics(pout, batch) + pred = {"losses": losses, "metrics": metrics} + self.validation_step_outputs.append(pred) + return pred + + def test_step(self, batch, batch_idx): + # if batch_idx<380 or batch_idx>381: + # return {} + # if "log_image_frequency" in self.cfg.eval and self.cfg.eval.log_image_frequency is not None: + # if batch_idx%self.cfg.eval.log_image_frequency!=0: + # return {} + with torch.no_grad(): + pout = TensorUtils.detach( + self.moduleseq.validate_step(batch, batch_idx=batch_idx) + ) + losses = self._compute_losses(pout, batch) + metrics = self._compute_metrics(pout, batch) + pred = {"losses": losses, "metrics": metrics} + flattened_pred = TensorUtils.flatten_dict(pred) + for k, v in flattened_pred.items(): + if isinstance(v, torch.Tensor): + flattened_pred[k] = ( + v.cpu().numpy().item() + if v.numel() == 1 + else v.cpu().numpy().mean().item() + ) + elif isinstance(v, np.ndarray): + flattened_pred[k] = v.item() if v.size == 1 else v.mean().item() + + self.log_dict(flattened_pred) + + if ( + "log_image_frequency" in self.cfg.eval + and self.cfg.eval.log_image_frequency is not None + ): + if batch_idx % self.cfg.eval.log_image_frequency == 0: + for comp_name, component in self.moduleseq.components.items(): + if hasattr(component, "log_pred_image"): + component.log_pred_image( + batch, + pout, + batch_idx, + self.cfg.eval.results_dir, + log_all_image=self.cfg.eval.log_all_image, + savegif=self.cfg.eval.get("savegif", True), + ) + + return flattened_pred + + def on_validation_epoch_end(self) -> None: + outputs = self.validation_step_outputs + for comp_name in self.moduleseq.components: + for k in outputs[0]["losses"][comp_name]: + m = torch.stack([o["losses"][comp_name][k] for o in outputs]).mean() + self.log("val/losses_" + comp_name + "_" + k, m, sync_dist=True) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m, sync_dist=True) + self.validation_step_outputs = [] + + def configure_optimizers(self): + pass + + def _compute_metrics(self, pred_batch, data_batch): + pass + + def log_pred_image( + self, + batch, + pred, + batch_idx, + logger, + **kwargs, + ): + if "pred" in pred: + pred = pred["pred"] + if "image" in pred: + N = pred["image"].shape[0] + h = int(np.ceil(np.sqrt(N))) + w = int(np.ceil(N / h)) + fig, ax = plt.subplots(h, w, figsize=(20 * h, 20 * w)) + canvas = FigureCanvas(fig) + image = pred["image"].detach().cpu().numpy().transpose(0, 2, 3, 1) + for i in range(N): + if h > 1 and w > 1: + hi = int(i / w) + wi = i - w * hi + ax[hi, wi].imshow(image[i]) + + elif h == 1 and w == 1: + ax.imshow(image[i]) + else: + ax[i].imshow(image[i]) + canvas.draw() # draw the canvas, cache the renderer + + image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + logger.experiment.log( + {"val_pred_image": [wandb.Image(image, caption="val_pred_image")]} + ) + del fig, ax, image + else: + if "data_batch" in pred: + batch = pred["data_batch"] + else: + if "scene_batch" in batch: + batch = batch["scene_batch"] + elif "agent_batch" in batch: + batch = batch["agent_batch"] + if isinstance(batch, SceneBatch) or isinstance(batch, AgentBatch): + batch = batch.__dict__ + + if "trajectories" not in pred or "image" not in batch: + return + + keys = ["trajectories", "agent_avail", "cond_traj"] + pred = {k: v for k, v in pred.items() if k in keys} + pred = TensorUtils.to_numpy(pred) + + if batch["image"].ndim == 5: + map = batch["image"][0, 0, -3:].cpu().numpy().transpose(1, 2, 0) + else: + map = batch["image"][0, -3:].cpu().numpy().transpose(1, 2, 0) + traj = pred["trajectories"][..., :2] + + avail = pred["agent_avail"].astype(float) + avail1 = avail.copy() + avail1[avail == 0] = np.nan + raster_from_agent = batch["raster_from_agent"].cpu().numpy() + if raster_from_agent.ndim == 2: + raster_from_agent = raster_from_agent[np.newaxis, :] + if traj.ndim == 6: + bs, Ne, numMode, Na, T = traj.shape[:5] + traj = traj * avail1[:, None, None, :, None, None] + else: + bs, numMode, Na, T = traj.shape[:4] + traj = traj * avail1[:, None, :, None, None] + Ne = 1 + raster_traj = GeoUtils.batch_nd_transform_points_np( + traj.reshape(bs, -1, 2), raster_from_agent + ).reshape(bs, Ne, -1, 2) + + cond_traj = pred["cond_traj"] if "cond_traj" in pred else None + + if cond_traj is None: + fig, ax = plt.subplots(figsize=(20, 20)) + canvas = FigureCanvas(fig) + ax.imshow(map) + ax.scatter( + raster_traj[0, 0, ..., 0], + raster_traj[0, 0, ..., 1], + color="c", + s=2, + marker="D", + ) + + else: + raster_cond_traj = GeoUtils.batch_nd_transform_points_np( + cond_traj[..., :2].reshape(bs, -1, 2), raster_from_agent + ).reshape(bs, Ne, -1, 2) + h = int(np.ceil(np.sqrt(Ne))) + w = int(np.ceil(Ne / h)) + fig, ax = plt.subplots(h, w, figsize=(20 * h, 20 * w)) + canvas = FigureCanvas(fig) + for i in range(Ne): + if h > 1 and w > 1: + hi = int(i / w) + wi = i - w * hi + ax[hi, wi].imshow(map) + ax[hi, wi].scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax[hi, wi].scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + elif h == 1 and w == 1: + ax.imshow(map) + ax.scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax.scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + else: + ax[i].imshow(map) + ax[i].scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax[i].scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + canvas.draw() # draw the canvas, cache the renderer + + image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + logger.experiment.log( + {"val_pred_image": [wandb.Image(image, caption="val_pred_image")]} + ) + del fig, ax, image + + def set_eval(self): + self.moduleseq.set_eval() + + def set_train(self): + self.moduleseq.set_train() + + def reset(self): + self.moduleseq.reset() diff --git a/diffstack/stacks/pred_stack.py b/diffstack/stacks/pred_stack.py new file mode 100644 index 0000000..f023130 --- /dev/null +++ b/diffstack/stacks/pred_stack.py @@ -0,0 +1,44 @@ +from diffstack.modules.module import Module, ModuleSequence, DataFormat +from trajdata import SceneBatch +import pytorch_lightning as pl +import diffstack.utils.tensor_utils as TensorUtils +import torch +import torch.nn as nn +import numpy as np +import torch.optim as optim +from diffstack.stacks.base import AVStack + + +class PredStack(AVStack): + def __init__(self, modules: ModuleSequence, cfg, batch_size=None, **kwargs): + super().__init__( + modules, cfg, batch_size=cfg.train.training.batch_size, **kwargs + ) + + def configure_optimizers(self): + optim_params = self.cfg.stack["predictor"].optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def _compute_metrics(self, pred_batch, data_batch): + metrics_dict = dict() + if hasattr(self.components["predictor"], "compute_metrics"): + metrics_dict.update( + self.components["predictor"].compute_metrics(pred_batch, data_batch) + ) + if ( + "trajectories" in pred_batch + and pred_batch["trajectories"].ndim == 5 + and pred_batch["trajectories"].size(1) > 0 + ): + metrics_dict["mode_diversity"] = ( + (pred_batch["trajectories"][:, 0] - pred_batch["trajectories"][:, 1]) + .norm() + .detach() + .cpu() + ) + + return metrics_dict diff --git a/diffstack/stacks/stack_factory.py b/diffstack/stacks/stack_factory.py new file mode 100644 index 0000000..7927497 --- /dev/null +++ b/diffstack/stacks/stack_factory.py @@ -0,0 +1,82 @@ +"""Factory methods for creating models""" +from diffstack.modules.module import ModuleSequence +from diffstack.modules.predictors.factory import ( + predictor_factory, +) +from diffstack.configs.config import Dict +import torch +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import List + +from diffstack.stacks.pred_stack import PredStack +from omegaconf import DictConfig, OmegaConf, open_dict + + +def get_checkpoint_dict(stack_cfg: DictConfig, module_names: List, device: str): + checkpoint_dict = defaultdict(lambda: None) + for module_name in module_names: + if ( + "load_checkpoint" in stack_cfg[module_name] + and stack_cfg[module_name].load_checkpoint + ): + ckpt_path = Path(stack_cfg[module_name].load_checkpoint).expanduser() + checkpoint_dict[module_name] = torch.load(ckpt_path, map_location=device) + return checkpoint_dict + + +def stack_factory(cfg: DictConfig, model_registrar=None, log_writer=None, device=None): + """ + A factory for creating training stacks + + Args: + cfg (ExperimentConfig): an ExperimentConfig object, + Returns: + stack: pl.LightningModule + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + if cfg.stack.stack_type == "pred": + predictor_config = cfg.stack["predictor"] + + checkpoint_dict = get_checkpoint_dict(cfg.stack, ["predictor"], device) + if "env" in cfg: + if isinstance(predictor_config, Dict): + predictor_config.unlock() + predictor_config.env = cfg.env + predictor_config.lock() + elif isinstance(predictor_config, DictConfig): + OmegaConf.set_struct(predictor_config, True) + with open_dict(predictor_config): + predictor_config.env = cfg.env + predictor = predictor_factory( + model_registrar, + predictor_config, + log_writer, + device, + checkpoint=checkpoint_dict["predictor"], + ) + + modules = ModuleSequence( + OrderedDict(predictor=predictor), + model_registrar, + predictor_config, + log_writer, + device, + ) + if ( + "monitor_key" in predictor_config + and predictor_config["monitor_key"] is not None + ): + avstack = PredStack( + modules, cfg, monitor_key=predictor_config["monitor_key"] + ) + else: + avstack = PredStack( + modules, cfg, monitor_key=predictor.checkpoint_monitor_keys + ) + return avstack + + else: + raise NotImplementedError("The type of stack structure is not recognized!") diff --git a/diffstack/utils/__init__.py b/diffstack/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffstack/utils/algo_utils.py b/diffstack/utils/algo_utils.py new file mode 100644 index 0000000..2d77661 --- /dev/null +++ b/diffstack/utils/algo_utils.py @@ -0,0 +1,283 @@ +import torch +from torch import optim as optim + +import diffstack.utils.tensor_utils as TensorUtils +from diffstack import dynamics as dynamics +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import transform_points_tensor, calc_distance_map +from diffstack.utils.l5_utils import get_last_available_index +from diffstack.utils.loss_utils import goal_reaching_loss, trajectory_loss, collision_loss + + +def generate_proxy_mask(orig_loc, radius,mode="L1"): + """ mask out area near the existing samples to boost diversity + + Args: + orig_loc (torch.tensor): original sample location, 1 for sample, 0 for background + radius (int): radius in pixel space + mode (str, optional): Defaults to "L1". + + Returns: + torch.tensor[dtype=torch.bool]: mask for generating new samples + """ + dis_map = calc_distance_map(orig_loc,max_dis = radius+1,mode=mode) + return dis_map<=radius + + +def decode_spatial_prediction(prob_map, residual_yaw_map, num_samples=None, clearance=None): + """ + Decode spatial predictions (e.g., UNet output) to a list of locations + Args: + prob_map (torch.Tensor): probability of each spatial location [B, H, W] + residual_yaw_map (torch.Tensor): location residual (d_x, d_yy) and yaw of each location [B, 3, H, W] + num_samples (int): (optional) if specified, take # of samples according to the discrete prob_map distribution. + default is None, which is to take the max + Returns: + pixel_loc (torch.Tensor): coordinates of each predicted location before applying residual [B, N, 2] + residual_pred (torch.Tensor): residual of each predicted location [B, N, 2] + yaw_pred (torch.Tensor): yaw of each predicted location [B, N, 1] + pixel_prob (torch.Tensor): probability of each sampled prediction [B, N] + """ + # decode map as predictions + b, h, w = prob_map.shape + flat_prob_map = prob_map.flatten(start_dim=1) + if num_samples is None: + # if num_samples is not specified, take the maximum-probability location + pixel_prob, pixel_loc_flat = torch.max(flat_prob_map, dim=1) + pixel_prob = pixel_prob.unsqueeze(1) + pixel_loc_flat = pixel_loc_flat.unsqueeze(1) + else: + # otherwise, use the probability map as a discrete distribution of location predictions + dist = torch.distributions.Categorical(probs=flat_prob_map) + if clearance is None: + pixel_loc_flat = dist.sample((num_samples,)).permute( + 1, 0) # [n_sample, batch] -> [batch, n_sample] + else: + proximity_map = torch.ones_like(prob_map,requires_grad=False) + pixel_loc_flat = list() + for i in range(num_samples): + dist = torch.distributions.Categorical(probs=flat_prob_map*proximity_map.flatten(start_dim=1)) + sample_i = dist.sample() + sample_mask = torch.zeros_like(flat_prob_map,dtype=torch.bool) + sample_mask[torch.arange(b),sample_i]=True + proxy_mask = generate_proxy_mask(sample_mask.reshape(-1,h,w),clearance) + proximity_map = torch.logical_or(proximity_map,torch.logical_not(proxy_mask)) + pixel_loc_flat.append(sample_i) + + pixel_loc_flat = torch.stack(pixel_loc_flat,dim=1) + pixel_prob = torch.gather(flat_prob_map, dim=1, index=pixel_loc_flat) + + local_pred = torch.gather( + input=torch.flatten(residual_yaw_map, 2), # [B, C, H * W] + dim=2, + index=TensorUtils.unsqueeze_expand_at( + pixel_loc_flat, size=3, dim=1) # [B, C, num_samples] + ).permute(0, 2, 1) # [B, C, N] -> [B, N, C] + + residual_pred = local_pred[:, :, 0:2] + yaw_pred = local_pred[:, :, 2:3] + + pixel_loc_x = torch.remainder(pixel_loc_flat, w).float() + pixel_loc_y = torch.floor(pixel_loc_flat.float() / float(w)).float() + pixel_loc = torch.stack((pixel_loc_x, pixel_loc_y), dim=-1) # [B, N, 2] + + return pixel_loc, residual_pred, yaw_pred, pixel_prob + + +def get_spatial_goal_supervision(data_batch): + """Get supervision for training the spatial goal network.""" + b, _, h, w = data_batch["image"].shape # [B, C, H, W] + + # use last available step as goal location + goal_index = get_last_available_index( + data_batch["fut_mask"])[:, None, None] + + # gather by goal index + goal_pos_agent = torch.gather( + data_batch["fut_pos"], # [B, T, 2] + dim=1, + index=goal_index.expand(-1, 1, + data_batch["fut_pos"].shape[-1]) + ) # [B, 1, 2] + + goal_yaw_agent = torch.gather( + data_batch["fut_yaw"], # [B, T, 1] + dim=1, + index=goal_index.expand(-1, 1, data_batch["fut_yaw"].shape[-1]) + ) # [B, 1, 1] + + # create spatial supervisions + goal_pos_raster = transform_points_tensor( + goal_pos_agent, + data_batch["raster_from_agent"].float() + ).squeeze(1) # [B, 2] + # make sure all pixels are within the raster image + goal_pos_raster[:, 0] = goal_pos_raster[:, 0].clip(0, w - 1e-5) + goal_pos_raster[:, 1] = goal_pos_raster[:, 1].clip(0, h - 1e-5) + + goal_pos_pixel = torch.floor(goal_pos_raster).float() # round down pixels + # compute rounding residuals (range 0-1) + goal_pos_residual = goal_pos_raster - goal_pos_pixel + # compute flattened pixel location + goal_pos_pixel_flat = goal_pos_pixel[:, 1] * w + goal_pos_pixel[:, 0] + raster_sup_flat = TensorUtils.to_one_hot( + goal_pos_pixel_flat.long(), num_class=h * w) + raster_sup = raster_sup_flat.reshape(b, h, w) + return { + "goal_position_residual": goal_pos_residual, # [B, 2] + "goal_spatial_map": raster_sup, # [B, H, W] + "goal_position_pixel": goal_pos_pixel, # [B, 2] + "goal_position_pixel_flat": goal_pos_pixel_flat, # [B] + "goal_position": goal_pos_agent.squeeze(1), # [B, 2] + "goal_yaw": goal_yaw_agent.squeeze(1), # [B, 1] + "goal_index": goal_index.reshape(b) # [B] + } + + +def get_spatial_trajectory_supervision(data_batch): + """Get supervision for training the learned occupancy metric.""" + b, _, h, w = data_batch["image"].shape # [B, C, H, W] + t = data_batch["fut_pos"].shape[-2] + # create spatial supervisions + pos_raster = transform_points_tensor( + data_batch["fut_pos"], + data_batch["raster_from_agent"].float() + ) # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + pos_pixel = torch.floor(pos_raster).float() # round down pixels + + # compute flattened pixel location + pos_pixel_flat = pos_pixel[..., 1] * w + pos_pixel[..., 0] + raster_sup_flat = TensorUtils.to_one_hot( + pos_pixel_flat.long(), num_class=h * w) + raster_sup = raster_sup_flat.reshape(b, t, h, w) + return { + "traj_spatial_map": raster_sup, # [B, T, H, W] + "traj_position_pixel": pos_pixel, # [B, T, 2] + "traj_position_pixel_flat": pos_pixel_flat # [B, T] + } + + +def optimize_trajectories( + init_u, + init_x, + target_trajs, + target_avails, + dynamics_model, + step_time: float, + data_batch=None, + goal_loss_weight=1.0, + traj_loss_weight=0.0, + coll_loss_weight=0.0, + num_optim_iterations: int = 50 +): + """An optimization-based trajectory generator""" + curr_u = init_u.detach().clone() + curr_u.requires_grad = True + action_optim = optim.LBFGS( + [curr_u], max_iter=20, lr=1.0, line_search_fn='strong_wolfe') + + for oidx in range(num_optim_iterations): + def closure(): + action_optim.zero_grad() + + # get trajectory with current params + x = dynamics_model.forward_dynamics( + x0=init_x, + u=curr_u, + ) + pos = dynamics_model.state2pos(x) + yaw = dynamics_model.state2yaw(x) + curr_trajs = torch.cat((pos, yaw), dim=-1) + # compute trajectory optimization losses + losses = dict() + losses["goal_loss"] = goal_reaching_loss( + predictions=curr_trajs, + targets=target_trajs, + availabilities=target_avails + ) * goal_loss_weight + losses["traj_loss"] = trajectory_loss( + predictions=curr_trajs, + targets=target_trajs, + availabilities=target_avails + ) * traj_loss_weight + if coll_loss_weight > 0: + assert data_batch is not None + coll_edges = batch_utils().get_edges_from_batch( + data_batch, + ego_predictions=dict(positions=pos, yaws=yaw) + ) + for c in coll_edges: + coll_edges[c] = coll_edges[c][:, :target_trajs.shape[-2]] + vv_edges = dict(VV=coll_edges["VV"]) + if vv_edges["VV"].shape[0] > 0: + losses["coll_loss"] = collision_loss( + vv_edges) * coll_loss_weight + + total_loss = torch.hstack(list(losses.values())).sum() + + # backprop + total_loss.backward() + return total_loss + action_optim.step(closure) + + final_raw_trajs = dynamics_model.forward_dynamics( + x0=init_x, + u=curr_u, + ) + final_pos = dynamics_model.state2pos(final_raw_trajs) + final_yaw = dynamics_model.state2yaw(final_raw_trajs) + final_trajs = torch.cat((final_pos, final_yaw), dim=-1) + losses = dict() + losses["goal_loss"] = goal_reaching_loss( + predictions=final_trajs, + targets=target_trajs, + availabilities=target_avails + ) + losses["traj_loss"] = trajectory_loss( + predictions=final_trajs, + targets=target_trajs, + availabilities=target_avails + ) + + return dict(positions=final_pos, yaws=final_yaw), final_raw_trajs, curr_u, losses + + +def combine_ego_agent_data(batch, ego_keys, agent_keys, mask=None): + assert len(ego_keys) == len(agent_keys) + combined_batch = dict() + for ego_key, agent_key in zip(ego_keys, agent_keys): + if mask is None: + size_dim0 = batch[agent_key].shape[0]*batch[agent_key].shape[1] + combined_batch[ego_key] = torch.cat((batch[ego_key], batch[agent_key].reshape( + size_dim0, *batch[agent_key].shape[2:])), dim=0) + else: + size_dim0 = mask.sum() + combined_batch[ego_key] = torch.cat((batch[ego_key], batch[agent_key][mask].reshape( + size_dim0, *batch[agent_key].shape[2:])), dim=0) + return combined_batch + + +def yaw_from_pos(pos: torch.Tensor, dt, yaw_correction_speed=0.): + """ + Compute yaws from position sequences. Optionally suppress yaws computed from low-velocity steps + + Args: + pos (torch.Tensor): sequence of positions [..., T, 2] + dt (float): delta timestep to compute speed + yaw_correction_speed (float): zero out yaw change when the speed is below this threshold (noisy heading) + + Returns: + accum_yaw (torch.Tensor): sequence of yaws [..., T-1, 1] + """ + + pos_diff = pos[..., 1:, :] - pos[..., :-1, :] + yaw = torch.atan2(pos_diff[..., 1], pos_diff[..., 0]) + delta_yaw = torch.cat((yaw[..., [0]], yaw[..., 1:] - yaw[..., :-1]), dim=-1) + speed = torch.norm(pos_diff, dim=-1) / dt + delta_yaw[speed < yaw_correction_speed] = 0. + accum_yaw = torch.cumsum(delta_yaw, dim=-1) + return accum_yaw[..., None] diff --git a/diffstack/utils/batch_utils.py b/diffstack/utils/batch_utils.py new file mode 100644 index 0000000..ec5747a --- /dev/null +++ b/diffstack/utils/batch_utils.py @@ -0,0 +1,299 @@ +import torch + +import diffstack.utils.trajdata_utils as av_utils +from diffstack import dynamics as dynamics +from diffstack.configs.base import ExperimentConfig + + +global BATCH_TYPE + +BATCH_TYPE = "trajdata" +# def set_global_batch_type(batch_type): +# global BATCH_TYPE +# assert batch_type in ["trajdata", "l5kit"] +# BATCH_TYPE = batch_type + + +def batch_utils(**kwargs): + if BATCH_TYPE == "trajdata": + return trajdataBatchUtils(**kwargs) + else: + raise NotImplementedError( + "Please set BATCH_TYPE in batch_utils.py to {trajdata, l5kit}" + ) + + +class BatchUtils(object): + """A base class for processing environment-independent batches""" + + def __init__(self, **kwargs): + if "parse" in kwargs: + self.parse = kwargs["parse"] + else: + self.parse = True + if "rasterize_mode" in kwargs: + self.rasterize_mode = kwargs["rasterize_mode"] + else: + self.rasterize_mode = "point" + + @staticmethod + def get_last_available_index(avails): + """ + Args: + avails (torch.Tensor): target availabilities [B, (A), T] + + Returns: + last_indices (torch.Tensor): index of the last available frame + """ + num_frames = avails.shape[-1] + inds = torch.arange(0, num_frames).to(avails.device) # [T] + inds = ( + avails > 0 + ).float() * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(dim=-1)[ + 1 + ] # [B, (A)] calculate the index of the last availale frame + return last_inds + + @staticmethod + def get_current_states(batch: dict, dyn_type: dynamics.DynType) -> torch.Tensor: + """Get the dynamic states of the current timestep""" + bs = batch["curr_speed"].shape[0] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, 6).to( + batch["curr_speed"].device + ) # [x, y, yaw, vel, dh, veh_len] + current_states[:, 3] = batch["curr_speed"].abs() + current_states[:, [4]] = ( + batch["hist_yaw"][:, 0] - batch["hist_yaw"][:, 1] + ).abs() + current_states[:, 5] = batch["extent"][:, 0] # [veh_len] + else: + current_states = torch.zeros(bs, 4).to( + batch["curr_speed"].device + ) # [x, y, vel, yaw] + current_states[:, 2] = batch["curr_speed"] + return current_states + + @classmethod + def get_current_states_all_agents( + cls, batch: dict, step_time, dyn_type: dynamics.DynType + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def parse_batch(data_batch): + raise NotImplementedError + + @staticmethod + def batch_to_raw_all_agents(data_batch, step_time): + raise NotImplementedError + + @staticmethod + def batch_to_target_all_agents(data_batch): + raise NotImplementedError + + @staticmethod + def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raise NotImplementedError + + @staticmethod + def generate_edges(raw_type, extents, pos_pred, yaw_pred): + raise NotImplementedError + + @staticmethod + def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ): + raise NotImplementedError + + @staticmethod + def gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask=None, + ): + raise NotImplementedError + + @staticmethod + def get_drivable_region_map(rasterized_map): + raise NotImplementedError + + @staticmethod + def get_modality_shapes(cfg: ExperimentConfig): + raise NotImplementedError + + +class trajdataBatchUtils(BatchUtils): + """Batch utils for trajdata""" + + def parse_batch(self, data_batch): + if self.parse: + return av_utils.parse_trajdata_batch(data_batch, self.rasterize_mode) + else: + return data_batch + + @staticmethod + def batch_to_raw_all_agents(data_batch, step_time): + raw_type = torch.cat( + (data_batch["agent_type"].unsqueeze(1), data_batch["neigh_types"]), + dim=1, + ).type(torch.int64) + if data_batch["num_neigh"].max() > 0: + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["neigh_hist_pos"], + ), + dim=1, + ) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["neigh_hist_yaw"], + ), + dim=1, + ) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["neigh_hist_mask"], + ), + dim=1, + ).bool() + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + data_batch["neigh_extents"][..., :2], + ), + dim=1, + ) + + curr_speed = torch.cat( + (data_batch["curr_speed"].unsqueeze(1), data_batch["neigh_curr_speed"]), + dim=1, + ) + else: + src_pos = data_batch["hist_pos"].unsqueeze(1) + src_yaw = data_batch["hist_yaw"].unsqueeze(1) + src_mask = data_batch["hist_mask"].unsqueeze(1) + extents = data_batch["extent"][..., :2].unsqueeze(1) + curr_speed = data_batch["curr_speed"].unsqueeze(1) + + return { + "hist_pos": src_pos, + "hist_yaw": src_yaw, + "curr_speed": curr_speed, + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": extents, + } + + @staticmethod + def batch_to_target_all_agents(data_batch): + pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["neigh_fut_pos"], + ), + dim=1, + ) + yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["neigh_fut_yaw"], + ), + dim=1, + ) + avails = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["neigh_fut_mask"], + ), + dim=1, + ) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + data_batch["neigh_extents"][..., :2], + ), + dim=1, + ) + + return {"fut_pos": pos, "fut_yaw": yaw, "fut_mask": avails, "extents": extents} + + @staticmethod + def get_current_states_all_agents( + batch: dict, step_time, dyn_type: dynamics.DynType + ) -> torch.Tensor: + if batch["hist_pos"].ndim == 3: + state_all = trajdataBatchUtils.batch_to_raw_all_agents(batch, step_time) + else: + state_all = batch + bs, na = state_all["curr_speed"].shape[:2] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, na, 6).to( + state_all["curr_speed"].device + ) # [x, y, yaw, vel, dh, veh_len] + current_states[:, :, :2] = state_all["hist_pos"][:, :, -1] + current_states[:, :, 3] = state_all["curr_speed"].abs() + current_states[:, :, [4]] = ( + state_all["hist_yaw"][:, :, -1] - state_all["hist_yaw"][:, :, 1] + ).abs() + current_states[:, :, 5] = state_all["extent"][:, :, -1] # [veh_len] + else: + current_states = torch.zeros(bs, na, 4).to( + state_all["curr_speed"].device + ) # [x, y, vel, yaw] + current_states[:, :, :2] = state_all["hist_pos"][:, :, -1] + current_states[:, :, 2] = state_all["curr_speed"] + current_states[:, :, 3:] = state_all["hist_yaw"][:, :, -1] + return current_states + + @staticmethod + def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raise NotImplementedError + + @staticmethod + def generate_edges(raw_type, extents, pos_pred, yaw_pred, batch_first=False): + return av_utils.generate_edges( + raw_type, extents, pos_pred, yaw_pred, batch_first=batch_first + ) + + @staticmethod + def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ): + return av_utils.gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ) + + @staticmethod + def gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask=None, + ): + return av_utils.gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask, + ) + + @staticmethod + def get_drivable_region_map(rasterized_map): + return av_utils.get_drivable_region_map(rasterized_map) + + def get_modality_shapes(self, cfg: ExperimentConfig): + return av_utils.get_modality_shapes(cfg, rasterize_mode=self.rasterize_mode) diff --git a/diffstack/utils/bezier_utils.py b/diffstack/utils/bezier_utils.py new file mode 100644 index 0000000..cae9ddd --- /dev/null +++ b/diffstack/utils/bezier_utils.py @@ -0,0 +1,97 @@ +import numpy as np +import math + +from scipy.special import comb, gamma + +def bezier_base(n,i,s): + b=math.factorial(n)/math.factorial(i)/math.factorial(n-i)*s**i*(1-s)**(n-i) + return b + + +def bezier_integral(alpha): + n=len(alpha)-1 + m=0 + for i in range(n+1): + m = m + alpha(1+i) * comb(n,i)*gamma(1+i)*gamma(n+1-i)/gamma(n+2) + return m + +def symbolic_dbezier(n): +# alpha_df=A*alpha_f + A = np.zeros([n,n+1]) + for i in range(0,n): + A[i,i]=-n + A[i,i+1]=n + return A + +class bezier_regressor(object): + def __init__(self, N=6, Nmin=8, Nmax=200,Gamma=1e-2): + self.N = N + self.dB = symbolic_dbezier(N) + self.ddB = symbolic_dbezier(N-1) + self.Nmin = Nmin + self.Nmax = Nmax + self.Bez_matr = dict() + self.bez_reg = dict() + self.Bez_matr_der = dict() + self.Bez_matr_dder = dict() + for i in range(Nmin,Nmax+1): + t=np.linspace(0,1,i) + self.Bez_matr[i]=np.zeros([i,self.N+1]) + self.Bez_matr_der[i]=np.zeros([i,self.N]) + self.Bez_matr_dder[i]=np.zeros([i,self.N-1]) + for j in range(i): + for k in range(self.N+1): + self.Bez_matr[i][j,k]=comb(self.N,k)*(1-t[j])**(self.N-k)*t[j]**k + if k DictConfig: + fullpath = Path(file_name) + config_path = Path(diffstack.__path__[0]).parent / "config" + config_name = ( + fullpath.absolute().relative_to(config_path.absolute()).with_suffix("") + ) + # Need it to be relative to this file, not the working directory. + rel_config_path = os.path.relpath(config_path, Path(__file__).parent) + + # Initialize configuration management system + hydra.core.global_hydra.GlobalHydra.instance().clear() # reinitialize hydra if already initialized + hydra.initialize(config_path=str(rel_config_path)) + cfg = hydra.compose(config_name=str(config_name), overrides=[]) + + return cfg + + +####### Functions below are related to tbsim style config classes + + +def recursive_update(config: Dict, external_dict: dict): + leftover = list() + for k, v in external_dict.items(): + if k in config and isinstance(config[k], dict): + config[k] = recursive_update(config[k], v) + else: + leftover.append(k) + + leftover_dict = {k: external_dict[k] for k in leftover} + config.update(**leftover_dict) + return config + + +def recursive_update_flat(config: Dict, external_dict: dict): + left_over = dict() + for k, v in external_dict.items(): + assert not isinstance(v, dict) + if k in config: + assert not isinstance(config[k], dict) + config[k] = v + else: + left_over[k] = v + if len(left_over) > 0: + for k, v in config.items(): + if isinstance(v, dict): + config[k], leftover = recursive_update_flat(v, left_over) + return config, left_over + + +def get_experiment_config_from_file(file_path, locked=False): + ext_cfg = json.load(open(file_path, "r")) + cfg = get_registered_experiment_config(ext_cfg["registered_name"]) + cfg = recursive_update(cfg, ext_cfg) + cfg.lock(locked) + return cfg + + +def translate_trajdata_cfg(cfg: ExperimentConfig): + rcfg = Dict() + # assert cfg.stack.step_time == 0.5 # TODO: support interpolation + + if ( + "predictor" in cfg.stack + and "scene_centric" in cfg.stack["predictor"] + and cfg.stack["predictor"].scene_centric + ): + rcfg.centric = "scene" + else: + rcfg.centric = "agent" + if "standardize_data" in cfg.env.data_generation_params: + rcfg.standardize_data = cfg.env.data_generation_params.standardize_data + else: + rcfg.standardize_data = True + if "predictor" in cfg.stack: + rcfg.step_time = cfg.stack["predictor"].step_time + rcfg.history_num_frames = cfg.stack["predictor"].history_num_frames + rcfg.future_num_frames = cfg.stack["predictor"].future_num_frames + elif "planner" in cfg.stack: + rcfg.step_time = cfg.stack["planner"].step_time + rcfg.history_num_frames = cfg.stack["planner"].history_num_frames + rcfg.future_num_frames = cfg.stack["planner"].future_num_frames + if "remove_parked" in cfg.env: + rcfg.remove_parked = cfg.env.remove_parked + rcfg.trajdata_source_root = cfg.train.trajdata_source_root + rcfg.trajdata_val_source_root = cfg.train.trajdata_val_source_root + rcfg.trajdata_source_train = cfg.train.trajdata_source_train + rcfg.trajdata_source_valid = cfg.train.trajdata_source_valid + rcfg.trajdata_source_test = cfg.train.trajdata_source_test + rcfg.trajdata_test_source_root = cfg.train.trajdata_test_source_root + rcfg.dataset_path = cfg.train.dataset_path + + rcfg.max_agents_distance = cfg.env.data_generation_params.max_agents_distance + rcfg.num_other_agents = cfg.env.data_generation_params.other_agents_num + rcfg.max_agents_distance_simulation = cfg.env.simulation.distance_th_close + rcfg.pixel_size = cfg.env.rasterizer.pixel_size + rcfg.raster_size = int(cfg.env.rasterizer.raster_size) + rcfg.raster_center = cfg.env.rasterizer.ego_center + rcfg.yaw_correction_speed = cfg.env.data_generation_params.yaw_correction_speed + rcfg.incl_neighbor_map = cfg.env.incl_neighbor_map + rcfg.incl_vector_map = cfg.env.incl_vector_map + rcfg.incl_raster_map = cfg.env.incl_raster_map + rcfg.calc_lane_graph = cfg.env.calc_lane_graph + rcfg.other_agents_num = cfg.env.data_generation_params.other_agents_num + rcfg.max_num_lanes = cfg.env.get("max_num_lanes", 15) + rcfg.remove_single_successor = cfg.env.get("remove_single_successor", False) + rcfg.num_lane_pts = cfg.env.get("num_lane_pts", 20) + if "vectorize_lane" in cfg.env.data_generation_params: + rcfg.vectorize_lane = cfg.env.data_generation_params.vectorize_lane + + else: + rcfg.vectorize_lane = "None" + + rcfg.lock() + return rcfg + + +def boolean_string(s): + if s not in {"False", "True", "0", "1", "false", "true", "on", "off"}: + raise ValueError("Not a valid boolean string") + return s in ["True", "1", "true", "on"] diff --git a/diffstack/utils/diffusion_utils.py b/diffstack/utils/diffusion_utils.py new file mode 100644 index 0000000..bfa70e1 --- /dev/null +++ b/diffstack/utils/diffusion_utils.py @@ -0,0 +1,256 @@ +""" +Various utilities for neural networks. +""" + +import math + +import torch as torch +import torch.nn as nn +import numpy as np + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor, mask=None): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(8, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def mean_flat(tensor, mask=None): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def mean_sum(tensor, mask): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return (tensor*mask).sum()/mask.sum().clip(min=1) + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffstack/utils/dist_utils.py b/diffstack/utils/dist_utils.py new file mode 100644 index 0000000..b660507 --- /dev/null +++ b/diffstack/utils/dist_utils.py @@ -0,0 +1,680 @@ +import torch +from torch.nn.functional import one_hot, gumbel_softmax +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.dynamics.base import Dynamics +from diffstack.dynamics.unicycle import Unicycle +import abc +import numpy as np +import diffstack.utils.geometry_utils as GeoUtils + + +def categorical_sample(pi, num_sample): + bs, M = pi.shape + pi0 = torch.cat([torch.zeros_like(pi[:, :1]), pi[:, :-1]], -1) + pi_cumsum = pi0.cumsum(-1) + rand_num = torch.rand([bs, num_sample], device=pi.device) + flag = rand_num.unsqueeze(-1) > pi_cumsum.unsqueeze(-2) + idx = torch.argmax((flag * torch.arange(0, M, device=pi.device)[None, None]), -1) + return idx + + +def categorical_psample_wor( + logpi, num_sample, num_factor=None, factor_mask=None, relevance_score=None +): + """pseodo-sample (pick the most probable modes) from a categorical distribution without replacement + the distribution is assumed to be a factorized categorical distribution + + Args: + logpi (probability): [B,M,D], M is the number of factors, D is the number of categories + num_sample (_type_): number of samples + num_factor (int): If not None, pick the top num_factor factors with the highest probability variation + + """ + B, M, D = logpi.shape + if num_factor is None: + num_factor = M + # assert D**num_factor>=num_sample + if num_factor != M: + if relevance_score is None: + # default to use the maximum mode probability as measure + relevance_score = -logpi.max(-1)[0] + if factor_mask is not None: + relevance_score.masked_fill_( + torch.logical_not(factor_mask), relevance_score.min() - 1 + ) + idx = torch.topk(relevance_score, num_factor, dim=1)[1] + else: + idx = torch.arange(M, device=logpi.device)[None, :].repeat_interleave(B, 0) + factor_logpi = torch.gather( + logpi, 1, idx.unsqueeze(-1).repeat_interleave(D, -1) + ) # Factors are chosen + # calculate the product log probability of the factors + num_factor = factor_logpi.shape[1] + + prod_logpi = sum( + [ + factor_logpi[:, i].reshape(B, *[1] * i, D, *[1] * (num_factor - i - 1)) + for i in range(num_factor) + ] + ) # D**num_factor + + prod_logpi_flat = prod_logpi.view(B, -1) + factor_samples = torch.topk(prod_logpi_flat, num_sample, 1)[1] + factor_sample_idx = list() + for i in range(num_factor): + factor_sample_idx.append(factor_samples % D) + factor_samples = torch.div(factor_samples, D, rounding_mode="floor") + factor_sample_idx = torch.stack(factor_sample_idx, -1).flip(-1) + # for unselected factors, pick the maximum likelihood mode + samples = logpi.argmax(-1).unsqueeze(1).repeat_interleave(num_sample, 1) + samples.scatter_( + 2, idx.unsqueeze(1).repeat_interleave(num_sample, 1), factor_sample_idx + ) + + return samples, idx + + # nonfactor_pi = torch.gather(pi,1,nonidx.unsqueeze(-1).repeat_interleave(D,-1)) + + +class BaseDist(abc.ABC): + @abc.abstractmethod + def rsample(self, num_sample): + pass + + @abc.abstractmethod + def get_dist(self): + pass + + @abc.abstractmethod + def detach_(self): + pass + + @abc.abstractmethod + def index_select_(self, idx): + pass + + +class MAGaussian(BaseDist): + def __init__(self, mu, var, K, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution + + Args: + mu (torch.Tensor): [B,N,D] mean + var (torch.Tensor): [B,N,D] variance + K (torch.Tensor): [B,N,D,L] coefficient of shared scene variance + var = var + K @ K.T + """ + self.mu = mu + self.var = var + self.K = K + self.L = K.shape[-1] + self.min_var = min_var + self.delta_x_clamp = delta_x_clamp + + def rsample(self, num_sample): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, N, D = self.mu.shape + eps = torch.randn(B, num_sample, N, D, device=self.mu.device) + scene_eps = torch.randn(B, num_sample, 1, self.L, device=self.mu.device) + return ( + self.mu.unsqueeze(-3) + + eps * self.var.sqrt().unsqueeze(-3) + + (self.K.unsqueeze(1) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + + def get_dist(self, mask): + B, N, D = self.mu.shape + + var_diag = TensorUtils.block_diag_from_cat( + torch.diag_embed(self.var + self.min_var) + ) + var_inv_diag = torch.linalg.pinv(var_diag) + C = torch.eye(self.K.shape[-1], device=self.K.device)[None].repeat_interleave( + B, 0 + ) + K = self.K.reshape(B, N * D, self.L) + K_T = K.transpose(-1, -2) + var = var_diag + K @ K_T + if self.L > N * D: + var_inv = ( + var_inv_diag + - var_inv_diag + @ K + @ torch.linalg.pinv(C + K_T @ var_inv_diag @ K) + @ K_T + @ var_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return self.mu, var, var_inv + + def get_log_likelihood(self, xp, mask): + # not tested + B = self.mu.shape[0] + mu, var, var_inv = self.get_dist(mask) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu - self.delta_x_clamp), mu + self.delta_x_clamp + ) + delta_x = xp - mu + delta_x *= mask + delta_x = delta_x.reshape(B, -1) + mask_var = torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.mu.shape[-1], -1) + ) + log_prob = 0.5 * ( + torch.logdet(var_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_inv @ delta_x.unsqueeze(-1)).flatten() + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + return log_prob + + def detach_(self): + self.mu = self.mu.detach() + self.var = self.var.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu = self.mu[idx] + self.var = self.varCategorical[idx] + self.K = self.K[idx] + + +class MADynGaussian(BaseDist): + def __init__(self, mu_u, var_u, var_x, K, dyn, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution with dynamics, input follows MAGaussian + + Args: + mu_u (torch.Tensor): [B,N,udim] mean + var_u (torch.Tensor): [B,N,udim] independent variance of u + var_x (torch.Tensor): [B,N,xdim] independent variance of x + K (torch.Tensor): [B,N,D,L] coefficient of shared scene variance + dyn (Dynamics): dynamics object + var = var + K @ K.T + """ + self.mu_u = mu_u + self.var_u = var_u + self.var_x = var_x + self.K = K + self.L = K.shape[-1] + self.dyn = dyn + self.min_var = min_var + self.delta_x_clamp = delta_x_clamp + + def rsample(self, x0, num_sample): + """sample from the distribution + + Args: + x0 (torch.Tensor): [B,N,xdim] initial state + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,xdim] sample + """ + B, N, D = self.mu_u.shape + eps = torch.randn(B, num_sample, N, D, device=self.mu_u.device) + scene_eps = torch.randn(B, num_sample, 1, self.L, device=self.mu_u.device) + u_sample = ( + self.mu_u.unsqueeze(1) + + eps * self.var_u.sqrt().unsqueeze(1) + + (self.K.unsqueeze(1) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + x_sample = self.dyn.step( + x0.unsqueeze(1).repeat_interleave(num_sample, 1), u_sample + ) + x_sample += self.var_x.sqrt().unsqueeze(1) * torch.randn_like(x_sample) + return x_sample + + def get_dist(self, x0, mask): + B, N, D = self.mu_u.shape + + mu_x, _, jacu = self.dyn.step(x0, self.mu_u, bound=False, return_jacobian=True) + if mask is not None: + jacu *= mask.unsqueeze(-1) + # var_x_total = J@var_u@J.T+var_x + var_x = jacu @ torch.diag_embed(self.var_u) @ jacu.transpose( + -1, -2 + ) + torch.diag_embed(self.var_x + self.min_var) + + var_x_inv = torch.linalg.pinv(var_x) + var_x_diag = TensorUtils.block_diag_from_cat(var_x) + var_x_inv_diag = TensorUtils.block_diag_from_cat(var_x_inv) + blk_jacu = TensorUtils.block_diag_from_cat(jacu) + K = self.K.reshape(B, N * D, self.L) + + JK = blk_jacu @ K + JK_T = JK.transpose(-1, -2) + C = torch.eye(self.K.shape[-1], device=self.K.device)[None].repeat_interleave( + B, 0 + ) + + var = var_x_diag + JK @ JK_T + + if self.L > N * D: + # Woodbury matrix identity + var_inv = ( + var_x_inv_diag + - var_x_inv_diag + @ JK + @ torch.linalg.pinv(C + JK_T @ var_x_inv_diag @ JK) + @ JK_T + @ var_x_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return mu_x, var, var_inv + + def get_log_likelihood(self, x0, xp, mask): + B, N = self.mu_u.shape[:2] + mu_x, total_var_x, var_x_inv = self.get_dist(x0, mask) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu_x - self.delta_x_clamp), mu_x + self.delta_x_clamp + ) + # # ground truth + delta_x = xp - mu_x + delta_x *= mask + mask_var = torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.dyn.xdim, -1) + ) + # hack: write delta function for each dynamics + delta_x[..., 3] = GeoUtils.round_2pi(delta_x[..., 3]) + delta_x = delta_x.reshape(B, -1) + log_prob = 0.5 * ( + torch.logdet(var_x_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_x_inv @ delta_x.unsqueeze(-1)).flatten() + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + return log_prob + + def detach_(self): + self.mu_u = self.mu_u.detach() + self.var_u = self.var_u.detach() + self.var_x = self.var_x.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu_u = self.mu_u[idx] + self.var_u = self.var_u[idx] + self.var_x = self.var_x[idx] + self.K = self.K[idx] + + +class MAGMM(BaseDist): + def __init__(self, mu, var, K, pi, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution + + Args: + mu (torch.Tensor): [B,M,N,D] mean + var (torch.Tensor): [B,M,N,D] variance + K (torch.Tensor): [B,M,N,D,L] coefficient of shared scene variance + var = var + K @ K.T + pi: [B,M] mixture weights + """ + self.mu = mu + self.var = var + self.K = K + self.L = K.shape[-1] + self.M = mu.shape[1] + self.pi = pi + self.logits = torch.log(pi) + self.pi_sum = pi.cumsum(-1) + self.delta_x_clamp = delta_x_clamp + self.min_var = min_var + + def rsample(self, num_sample, tau=None, infer=False): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, M, N, D = self.mu.shape + if tau is not None: + mode = gumbel_softmax( + self.logits[:, None].repeat_interleave(num_sample, 1), tau, hard=infer + ) + else: + mode = categorical_sample(self.pi, num_sample) + mode = one_hot(mode, self.M) + eps = torch.randn(B, M, num_sample, N, D, device=self.mu.device) + scene_eps = torch.randn(B, M, num_sample, 1, self.L, device=self.mu.device) + sample = ( + self.mu.unsqueeze(-3) + + eps * self.var.sqrt().unsqueeze(-3) + + (self.K.unsqueeze(2) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + sample = (sample * mode.transpose(1, 2).view(B, M, num_sample, 1, 1)).sum( + 1 + ) # B x num_sample x N x D + return sample + + def get_dist(self, mask): + B, M, N, D = self.mu.shape + var_tiled, mu_tiled, K_tiled = TensorUtils.join_dimensions( + (self.var, self.mu, self.K), 0, 2 + ) + var_diag = TensorUtils.block_diag_from_cat( + torch.diag_embed(var_tiled + self.min_var) + ) + var_inv_diag = torch.linalg.pinv(var_diag) + C = torch.eye(K_tiled.shape[-1], device=self.K.device)[None].repeat_interleave( + B * M, 0 + ) + K = K_tiled.reshape(B * M, N * D, self.L) + K_T = K.transpose(-1, -2) + var = var_diag + K @ K_T + if self.L > N * D: + var_inv = ( + var_inv_diag + - var_inv_diag + @ K + @ torch.linalg.pinv(C + K_T @ var_inv_diag @ K) + @ K_T + @ var_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return ( + self.mu, + var.reshape(B, M, N * D, N * D), + var_inv.reshape(B, M, N * D, N * D), + self.pi, + ) + + def get_log_likelihood(self, xp, mask): + # not tested + B, M = self.mu.shape[:2] + mu, var, var_inv, pi = self.get_dist(mask) + xp = xp[:, None].repeat_interleave(M, 1) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu - self.delta_x_clamp), mu + self.delta_x_clamp + ) + delta_x = xp - mu + + delta_x *= mask.unsqueeze(1) + delta_x = delta_x.reshape(B, M, -1) + mask_var = ( + torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.mu.shape[-1], -1) + ) + .unsqueeze(1) + .repeat_interleave(self.M, 1) + ) + log_prob_mode = 0.5 * ( + torch.logdet(var_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_inv @ delta_x.unsqueeze(-1)).reshape(B, M) + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + log_prob = torch.log((torch.exp(log_prob_mode) * pi).sum(-1)) + return log_prob + + def detach_(self): + self.mu = self.mu.detach() + self.var = self.var.detach() + self.K = self.K.detach() + self.pi = self.pi.detach() + + def index_select_(self, idx): + self.mu = self.mu[idx] + self.var = self.var[idx] + self.K = self.K[idx] + self.pi = self.pi[idx] + + +class MADynGMM(BaseDist): + def __init__(self, mu_u, var_u, var_x, K, pi, dyn, delta_x_clamp=1.0, min_var=1e-6): + """multiagent gaussian distribution with dynamics, input follows MAGaussian + + Args: + mu_u (torch.Tensor): [B,M,N,udim] mean + var_u (torch.Tensor): [B,M,N,udim] independent variance of u + var_x (torch.Tensor): [B,M,N,xdim] independent variance of x + K (torch.Tensor): [B,M,N,D,L] coefficient of shared scene variance + pi (torch.Tensor): [B,M] mixture weights + dyn (Dynamics): dynamics object + var = var + K @ K.T + """ + self.mu_u = mu_u + self.var_u = var_u + self.var_x = var_x + self.K = K + self.L = K.shape[-1] + self.M = mu_u.shape[1] + self.pi = pi + self.logits = torch.log(pi) + self.dyn = dyn + self.delta_x_clamp = delta_x_clamp + self.min_var = min_var + + def rsample(self, x0, num_sample, tau=None, infer=False): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, M, N, udim = self.mu_u.shape + if tau is not None: + mode = gumbel_softmax( + self.logits[:, None].repeat_interleave(num_sample, 1), tau, hard=infer + ) + else: + mode = categorical_sample(self.pi, num_sample) + mode = one_hot(mode, self.M) + eps = torch.randn(B, M, num_sample, N, udim, device=self.mu_u.device) + scene_eps = torch.randn(B, M, num_sample, 1, self.L, device=self.mu_u.device) + u_sample = ( + self.mu_u.unsqueeze(2) + + eps * self.var_u.sqrt().unsqueeze(2) + + (self.K.unsqueeze(2) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + x_sample = self.dyn.step( + x0[:, None, None] + .repeat_interleave(self.M, 1) + .repeat_interleave(num_sample, 2), + u_sample, + ) + x_sample += self.var_x.sqrt().unsqueeze(2) * torch.randn_like(x_sample) + x_sample = (x_sample * (mode.transpose(1, 2).view(B, M, num_sample, 1, 1))).sum( + 1 + ) + return x_sample + + def get_dist(self, x0, mask): + B, M, N, udim = self.mu_u.shape + var_x_tiled, var_u_tiled, mu_u_tiled, K_tiled = TensorUtils.join_dimensions( + (self.var_x, self.var_u, self.mu_u, self.K), 0, 2 + ) + x0_tiled = x0.repeat_interleave(M, 0) + mu_x, _, jacu = self.dyn.step( + x0_tiled, mu_u_tiled, bound=False, return_jacobian=True + ) + # var_x_total = J@var_u@J.T+var_x + var_x = jacu @ torch.diag_embed(var_u_tiled) @ jacu.transpose( + -1, -2 + ) + torch.diag_embed(var_x_tiled + self.min_var) + + var_x_inv = torch.linalg.pinv(var_x) + var_x_diag = TensorUtils.block_diag_from_cat(var_x) + var_x_inv_diag = TensorUtils.block_diag_from_cat(var_x_inv) + blk_jacu = TensorUtils.block_diag_from_cat(jacu) + K = K_tiled.reshape(B * M, N * udim, self.L) + JK = blk_jacu @ K + JK_T = JK.transpose(-1, -2) + C = torch.eye(K_tiled.shape[-1], device=self.K.device)[None].repeat_interleave( + B * M, 0 + ) + + var = var_x_diag + JK @ JK_T + if self.L > N * udim: + # Woodbury matrix identity + var_inv = ( + var_x_inv_diag + - var_x_inv_diag + @ JK + @ torch.linalg.pinv(C + JK_T @ var_x_inv_diag @ JK) + @ JK_T + @ var_x_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return ( + *TensorUtils.reshape_dimensions((mu_x, var, var_inv), 0, 1, (B, M)), + self.pi, + ) + + def get_log_likelihood(self, x0, xp, mask): + B, M, N = self.mu_u.shape[:3] + mu_x, total_var_x, var_x_inv, pi = self.get_dist(x0, mask) + + xp = xp[:, None].repeat_interleave(M, 1) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu_x - self.delta_x_clamp), mu_x + self.delta_x_clamp + ) + # # ground truth + delta_x = xp - mu_x + + delta_x *= mask.unsqueeze(1) + delta_x = delta_x.reshape(B, M, -1) + mask_var = ( + torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.dyn.xdim, -1) + ) + .unsqueeze(1) + .repeat_interleave(self.M, 1) + ) + log_prob_mode = 0.5 * ( + torch.logdet(var_x_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_x_inv @ delta_x.unsqueeze(-1)).reshape(B, M) + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + # log_prob = torch.log((torch.exp(log_prob_mode)*pi).sum(-1)) + # Apply EM: + log_prob = (log_prob_mode * pi).sum(-1) + return log_prob + + def detach_(self): + self.mu_u = self.mu_u.detach() + self.var_u = self.var_u.detach() + self.var_x = self.var_x.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu_u = self.mu_u[idx] + self.var_u = self.var_u[idx] + self.var_x = self.var_x[idx] + self.K = self.K[idx] + + +# class MADynEBM(BaseDist): +# """Multiagent Energy-based model with dynamics. + +# """ +# def __init__(self,dyn) + + +def test_ma(): + bs = 5 + N = 3 + d = 2 + L = 10 + mu = torch.randn(bs, N, d) + var = torch.randn(bs, N, d) ** 2 + K = torch.randn(bs, N, d, L) + dist = MAGaussian(mu, var, K) + sample = dist.rsample(15) + var_inv = dist.get_dist() + print("done") + + +def test_ma_dyn(): + bs = 5 + N = 3 + udim = 2 + xdim = 4 + L = 10 + dt = 0.1 + mu_u = torch.randn(bs, N, udim) + var_u = torch.randn(bs, N, udim) ** 2 + var_x = torch.randn(bs, N, xdim) ** 2 + K = torch.randn(bs, N, udim, L) + dyn = Unicycle(dt) + dist = MADynGaussian(mu_u, var_u, var_x, K, dyn) + x0 = torch.randn([bs, N, dyn.xdim]) + sample = dist.rsample(x0, 15) + var_inv = dist.get_dist(x0) + + +def test_categorical(): + pi = torch.tensor([0.1, 0.3, 0.6])[None].repeat_interleave(6, 0) + categorical_sample(pi, 10) + + +def test_GMM(): + D = 2 + B = 10 + M = 3 + N = 5 + L = 16 + + mu = torch.randn([B, M, N, D]) + var = torch.randn([B, M, N, D]) ** 2 + K = torch.randn([B, M, N, D, L]) + pi = torch.rand([B, M]) + pi = pi / pi.sum(-1, keepdim=True) + dist = MAGMM(mu, var, K, pi) + sample = dist.rsample(10) + _ = dist.get_dist() + x = torch.randn([B, N, D]) + log_prob = dist.get_log_likelihood(x) + + +def test_dyn_GMM(): + udim = 2 + xdim = 4 + dt = 0.1 + B = 10 + M = 3 + N = 5 + L = 16 + dyn = Unicycle(dt) + + mu_u = torch.randn(B, M, N, udim) + var_u = torch.randn(B, M, N, udim) ** 2 + var_x = torch.randn(B, M, N, xdim) ** 2 + K = torch.randn([B, M, N, udim, L]) + pi = torch.rand([B, M]) + pi = pi / pi.sum(-1, keepdim=True) + dist = MADynGMM(mu_u, var_u, var_x, K, pi, dyn) + + x = torch.randn([B, N, xdim]) + xp = torch.randn([B, N, xdim]) + _ = dist.get_dist(x) + sample = dist.rsample(x, 10, tau=0.5) + log_prob = dist.get_log_likelihood(x, xp) + print("done") + + +def test_sample_wor(): + pi = torch.randn([5, 6, 7]) + pi = pi**2 + pi = pi / pi.sum(-1, keepdim=True) + sample = categorical_sample_wor(pi, 20, 3) + + +if __name__ == "__main__": + test_sample_wor() diff --git a/diffstack/utils/env_utils.py b/diffstack/utils/env_utils.py new file mode 100644 index 0000000..90f2a35 --- /dev/null +++ b/diffstack/utils/env_utils.py @@ -0,0 +1,478 @@ +from typing import OrderedDict +import numpy as np +import pytorch_lightning as pl +import torch +import importlib +import os +from imageio import get_writer + +from diffstack.sim.TBSIM.base import BatchedEnv, BaseEnv +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.timer import Timers +from diffstack.sim.env_builders import EnvNuscBuilder +from diffstack.policies.wrappers import RolloutWrapper, Pos2YawWrapper +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.trajdata_utils import parse_trajdata_batch +from diffstack.utils.geometry_utils import VEH_VEH_collision +from trajdata.simulation import SimulationScene +import random + + +def collision_check(agents_posyaw, new_posyaw, agents_extent, new_extent): + new_posyaw_tiled = new_posyaw[np.newaxis, :].repeat(agents_posyaw.shape[0], 0) + new_extent_tiled = new_extent[np.newaxis, :].repeat(agents_posyaw.shape[0], 0) + dis = VEH_VEH_collision( + new_posyaw_tiled, agents_posyaw, new_extent_tiled, agents_extent + ) + return dis + + +def random_placing_neighbors(simscene, num_neighbors, coll_check=True): + init_modes = [0, 1, 2, 3, 4] + random.shuffle(init_modes) + init_modes = init_modes[:num_neighbors] + offset_x = 18.0 + offset_y = 5.0 + T = 10 + v_sigma = 0.3 + + dt = simscene.scene.dt + obs = simscene.get_obs() + if isinstance(simscene, SimulationScene): + obs = parse_trajdata_batch(obs) + + obs = TensorUtils.to_numpy(obs, ignore_if_unspecified=True) + num_new_agent = 0 + agent_names = [agent.name for agent in simscene.agents] + while "agent" + str(num_new_agent) in agent_names: + num_new_agent += 1 + ego_vel = obs["curr_speed"][0] + agent_plan = list() + for i in range(num_neighbors): + newagent_name = "agent" + str(num_new_agent + i) + newagent_type = 1 + newagent_extent = np.array([4.0, 2.5, 2.0]) + if init_modes[i] == 0: + # in front of the ego vehicle + newagent_state = np.array([[offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel - 2.0 + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 1: + # behind of the ego vehicle + newagent_state = np.array([[-offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + 2.0 + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 2: + # left of the ego vehicle + newagent_state = np.array([[0, -offset_y, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 3: + # right of the ego vehicle + newagent_state = np.array([[0, offset_y, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 4: + # two vehicle length ahead of the ego vehicle + newagent_state = np.array([[2 * offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel - 4.0 + np.random.randn() * v_sigma, 0.0, 40.0) + + newagent_state[:, 0] += np.arange(-T + 1, 1) * dt * vel + + add_flag = True + new_pos_global = GeoUtils.batch_nd_transform_points_np( + newagent_state[:, :2], obs["world_from_agent"][0] + ) + new_yaw_global = newagent_state[:, 2:] + obs["world_yaw"][0] + newagent_state_global = np.hstack((new_pos_global, new_yaw_global)) + if coll_check: + if "centroid" in obs: + agents_pos_global = obs["centroid"] + else: + agents_pos_global = GeoUtils.batch_nd_transform_points_np( + obs["history_positions"][:, -1], obs["world_from_agent"] + ) + agents_yaw_global = obs["world_yaw"] + agents_posyaw = np.hstack( + (agents_pos_global, agents_yaw_global[:, np.newaxis]) + ) + new_posyaw = np.hstack((new_pos_global[0], new_yaw_global[0])) + + dis = collision_check( + agents_posyaw, new_posyaw, obs["extent"], newagent_extent + ) + if dis.min() < 2.0: + add_flag = False + + if add_flag: + agent_plan.append( + dict( + name=newagent_name, + agent_state=newagent_state_global.tolist(), + initial_timestep=simscene.scene_ts - T + 1, + agent_type=newagent_type, + extent=newagent_extent.tolist(), + executed=False, + ) + ) + return agent_plan + + +def random_initial_adjust_plan(env, adjust_recipe): + adjust_plan = dict() + for simscene in env._current_scenes: + adjust_plan[simscene.scene.name] = dict( + remove_existing_neighbors=dict( + flag=adjust_recipe["remove_existing_neighbors"], executed=False + ), + agents=random_placing_neighbors( + simscene, adjust_recipe["initial_num_neighbors"] + ), + ) + + return adjust_plan + + +def rollout_episodes( + env, + policy, + num_episodes, + skip_first_n=1, + n_step_action=1, + render=False, + scene_indices=None, + start_frame_index_each_episode=None, + device=None, + obs_to_torch=True, + adjust_plan_recipe=None, + horizon=None, + seed_each_episode=None, +): + """ + Rollout an environment for a number of episodes + Args: + env (BaseEnv): a base simulation environment (gym-like) + policy (RolloutWrapper): a policy that controls agents in the environment + num_episodes (int): number of episodes to rollout for + skip_first_n (int): number of steps to skip at the begining + n_step_action (int): number of steps to take between querying models + render (bool): if True, return a sequence of rendered frames + scene_indices (tuple, list): (Optional) scenes indices to rollout with + start_frame_index_each_episode (List): (Optional) which frame to start each simulation episode from, + device: device to cast observation to + obs_to_torch: whether to cast observation to torch + adjust_plan_recipe (dict): (Optional) initialization condition, either a fixed plan or a recipe for random generation + horizon (int): (Optional) override horizon of the simulation + seed_each_episode (List): (Optional) a list of seeds, one for each episode + + Returns: + stats (dict): A dictionary of rollout stats for each episode (metrics, rewards, etc.) + info (dict): A dictionary of environment info for each episode + renderings (list): A list of rendered frames in the form of np.ndarray, one for each episode + """ + stats = {} + info = {} + renderings = [] + is_batched_env = isinstance(env, BatchedEnv) + timers = Timers() + adjust_plans = list() + if seed_each_episode is not None: + assert len(seed_each_episode) == num_episodes + if start_frame_index_each_episode is not None: + assert len(start_frame_index_each_episode) == num_episodes + + ego_policy = policy.unwrap()["Rollout.ego_policy"] + trace = list() + for ei in range(num_episodes): + if start_frame_index_each_episode is not None: + start_frame_index = start_frame_index_each_episode[ei] + else: + start_frame_index = None + + env.reset(scene_indices=scene_indices, start_frame_index=start_frame_index) + if adjust_plan_recipe is not None: + if "random_init_plan" in adjust_plan_recipe: + # recipe provided + if adjust_plan_recipe["random_init_plan"]: + adjust_recipe = adjust_plan_recipe + adjust_plan = random_initial_adjust_plan(env, adjust_recipe) + + else: + adjust_plan = None + adjust_recipe = None + + else: + # explicit plan provided + adjust_plan = adjust_plan_recipe + else: + adjust_plan = None + adjust_recipe = None + if adjust_plan is not None: + env.adjust_scene(adjust_plan) + + if seed_each_episode is not None: + env.update_random_seed(seed_each_episode[ei]) + + done = env.is_done() + counter = 0 + step_since_last_update = 0 + frames = list() + while not done: + if adjust_recipe is not None: + if step_since_last_update > adjust_recipe["num_frame_per_new_agent"]: + for simscene in env._current_scenes: + if simscene.scene_ts < simscene.scene.length_timesteps - 10: + extra_agent = random_placing_neighbors(simscene, 1) + adjust_plan[simscene.scene.name]["agents"] += extra_agent + env.adjust_scene(adjust_plan) + step_since_last_update = 0 + timers.tic("step") + with timers.timed("obs"): + obs = env.get_observation() + with timers.timed("to_torch"): + if obs_to_torch: + device = policy.device if device is None else device + obs_torch = TensorUtils.to_torch( + obs, device=device, ignore_if_unspecified=True + ) + else: + obs_torch = obs + + with timers.timed("network"): + action = policy.get_action(obs_torch, step_index=counter) + + if counter < skip_first_n: + # use GT action for the first N steps to warm up environment state (velocity, etc.) + gt_action = env.get_gt_action(obs) + action.ego = gt_action.ego + action.agents = gt_action.agents + env.step(action, num_steps_to_take=1, render=False) + counter += 1 + step_since_last_update += 1 + else: + with timers.timed("env_step"): + ims = env.step( + action, num_steps_to_take=n_step_action, render=render + ) # List of [num_scene, h, w, 3] + if render: + frames.extend(ims) + counter += n_step_action + step_since_last_update += n_step_action + timers.toc("step") + # print(timers) + + done = env.is_done() + + if horizon is not None and counter >= horizon: + break + metrics = env.get_metrics() + if hasattr(ego_policy, "savetrace") and ego_policy.savetrace: + trace.append(ego_policy.trace.copy()) + + for k, v in metrics.items(): + if k not in stats: + stats[k] = [] + if is_batched_env: # concatenate by scene + stats[k] = np.concatenate([stats[k], v], axis=0) + else: + stats[k].append(v) + + env_info = env.get_info() + for k, v in env_info.items(): + if k not in info: + if isinstance(v, dict): + info[k] = dict() + else: + info[k] = list() + + if is_batched_env: + if isinstance(v, dict): + info[k].update(v) + else: + info[k].extend(v) + else: + info[k].append(v) + del env_info + if hasattr(ego_policy, "reset"): + ego_policy.reset() + if render: + frames = np.stack(frames) + if is_batched_env: + # [step, scene] -> [scene, step] + frames = frames.transpose((1, 0, 2, 3, 4)) + renderings.append(frames) + if adjust_plan is not None: + adjust_plans.append(adjust_plan) + + multi_episodes_metrics = env.get_multi_episode_metrics() + stats.update(multi_episodes_metrics) + env.reset_multi_episodes_metrics() + + return stats, info, renderings, adjust_plans, trace + + +class RolloutCallback(pl.Callback): + """A pytorch-lightning callback function that runs rollouts during training""" + + def __init__( + self, + exp_config, + every_n_steps=100, + warm_start_n_steps=1, + verbose=False, + save_video=False, + video_dir=None, + ): + self._every_n_steps = every_n_steps + self._warm_start_n_steps = warm_start_n_steps + self._verbose = verbose + self._exp_cfg = exp_config.clone() + self._save_video = save_video + self._video_dir = video_dir + self.env = None + self.policy = None + self._eval_cfg = self._exp_cfg.eval + + def print_if_verbose(self, msg): + if self._verbose: + print(msg) + + def _get_env(self, device): + if self.env is not None: + return self.env + if self._eval_cfg.env == "nusc": + env_builder = EnvNuscBuilder( + eval_config=self._eval_cfg, exp_config=self._exp_cfg, device=device + ) + + else: + raise NotImplementedError( + "{} is not a valid env".format(self._eval_cfg.env) + ) + + env = env_builder.get_env() + self.env = env + return self.env + + def _get_policy(self, pl_module: pl.LightningModule): + if self.policy is not None: + return self.policy + policy_composers = importlib.import_module("tbsim.evaluation.policy_composers") + + composer_class = getattr(policy_composers, self._eval_cfg.eval_class) + composer = composer_class( + self._eval_cfg, pl_module.device, ckpt_root_dir=self._eval_cfg.ckpt_root_dir + ) + print("Building composer {}".format(self._eval_cfg.eval_class)) + + if self._exp_cfg.algo.name == "ma_rasterized": + policy, _ = composer.get_policy(predictor=pl_module) + else: + policy, _ = composer.get_policy(policy=pl_module) + + if self._eval_cfg.policy.pos_to_yaw: + policy = Pos2YawWrapper( + policy, + dt=self._exp_cfg.algo.step_time, + yaw_correction_speed=self._eval_cfg.policy.yaw_correction_speed, + ) + + if self._eval_cfg.env == "nusc": + rollout_policy = RolloutWrapper(agents_policy=policy) + elif self._eval_cfg.ego_only: + rollout_policy = RolloutWrapper(ego_policy=policy) + else: + rollout_policy = RolloutWrapper(ego_policy=policy, agents_policy=policy) + + self.policy = rollout_policy + return self.policy + + def _run_rollout(self, pl_module: pl.LightningModule, global_step: int): + rollout_policy = self._get_policy(pl_module) + env = self._get_env(pl_module.device) + + scene_i = 0 + eval_scenes = self._eval_cfg.eval_scenes + + result_stats = None + + while scene_i < len(eval_scenes): + scene_indices = eval_scenes[ + scene_i : scene_i + self._eval_cfg.num_scenes_per_batch + ] + scene_i += self._eval_cfg.num_scenes_per_batch + stats, info, renderings, _, _ = rollout_episodes( + env, + rollout_policy, + num_episodes=self._eval_cfg.num_episode_repeats, + n_step_action=self._eval_cfg.n_step_action, + render=self._save_video, + skip_first_n=self._eval_cfg.skip_first_n, + scene_indices=scene_indices, + start_frame_index_each_episode=self._eval_cfg.start_frame_index_each_episode, + seed_each_episode=self._eval_cfg.seed_each_episode, + horizon=self._eval_cfg.num_simulation_steps, + ) + + if result_stats is None: + result_stats = stats + result_stats["scene_index"] = np.array(info["scene_index"]) + else: + for k in stats: + result_stats[k] = np.concatenate( + [result_stats[k], stats[k]], axis=0 + ) + result_stats["scene_index"] = np.concatenate( + [result_stats["scene_index"], np.array(info["scene_index"])] + ) + + if self._save_video: + for ei, episode_rendering in enumerate(renderings): + for i, scene_images in enumerate(episode_rendering): + video_fn = "{}_{}_{}.mp4".format( + global_step, info["scene_index"][i], ei + ) + + writer = get_writer( + os.path.join(self._video_dir, video_fn), fps=10 + ) + print( + "video to {}".format( + os.path.join(self._video_dir, video_fn) + ) + ) + for im in scene_images: + writer.append_data(im) + writer.close() + return result_stats + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs, + batch, + batch_idx, + unused=0, + ) -> None: + should_run = ( + trainer.global_step >= self._warm_start_n_steps + and trainer.global_step % self._every_n_steps == 0 + ) + if should_run: + try: + self.print_if_verbose( + "\nStep %i rollout (%i episodes): " + % (trainer.global_step, len(self._eval_cfg.eval_scenes)) + ) + + stats = self._run_rollout(pl_module, trainer.global_step) + for k, v in stats.items(): + if "ttf" in k: # avoid cluttering the plot + continue + # Set on_step=True and on_epoch=False to force the logger to log stats at the step + # See https://github.com/PyTorchLightning/pytorch-lightning/issues/9772 for explanation + pl_module.log( + "rollout/" + k, np.mean(v), on_step=True, on_epoch=False + ) + self.print_if_verbose(("rollout/" + k, np.mean(v))) + self.print_if_verbose("\n") + except Exception as e: + print("Rollout failed because:") + print(e) diff --git a/diffstack/utils/experiment_utils.py b/diffstack/utils/experiment_utils.py new file mode 100644 index 0000000..5b8b266 --- /dev/null +++ b/diffstack/utils/experiment_utils.py @@ -0,0 +1,393 @@ +import json +import os +import itertools +from collections import namedtuple +from typing import List +from glob import glob +import subprocess +import shutil +from pathlib import Path + +import diffstack +from diffstack.configs.registry import get_registered_experiment_config +from diffstack.configs.config import Dict +from diffstack.configs.eval_config import EvaluationConfig +from diffstack.configs.base import ExperimentConfig + + +class Param(namedtuple("Param", "config_var alias value")): + pass + + +class ParamRange(namedtuple("Param", "config_var alias range")): + def linearize(self): + return [Param(self.config_var, self.alias, v) for v in self.range] + + def __len__(self): + return len(self.range) + + +class ParamConfig(object): + def __init__(self, params: List[Param] = None): + self.params = [] + self.aliases = [] + self.config_vars = [] + print(params) + if params is not None: + for p in params: + self.add(p) + + def add(self, param: Param): + assert param.config_var not in self.config_vars + assert param.alias not in self.aliases + self.config_vars.append(param.config_var) + self.aliases.append(param.alias) + self.params.append(param) + + def __str__(self): + char_to_remove = [" ", "(", ")", ";", "[", "]"] + name = [] + for p in self.params: + v_str = str(p.value) + for c in char_to_remove: + v_str = v_str.replace(c, "") + name.append(p.alias + v_str) + + return "_".join(name) + + def generate_config(self, base_cfg: Dict): + cfg = base_cfg.clone() + for p in self.params: + var_list = p.config_var.split(".") + c = cfg + # traverse the indexing list + for v in var_list[:-1]: + assert v in c, "{} is not a valid config variable".format(p.config_var) + c = c[v] + assert var_list[-1] in c, "{} is not a valid config variable".format( + p.config_var + ) + c[var_list[-1]] = p.value + cfg.name = str(self) + return cfg + + +class ParamSearchPlan(object): + def __init__(self): + self.param_configs = [] + self.const_params = [] + + def add_const_param(self, param: Param): + self.const_params.append(param) + + def add(self, param_config: ParamConfig): + for c in self.const_params: + param_config.add(c) + self.param_configs.append(param_config) + + def extend(self, param_configs: List[ParamConfig]): + for pc in param_configs: + self.add(pc) + + @staticmethod + def compose_concate(param_ranges: List[ParamRange]): + pcs = [] + for pr in param_ranges: + for p in pr.linearize(): + pcs.append(ParamConfig([p])) + return pcs + + @staticmethod + def compose_cartesian(param_ranges: List[ParamRange]): + """Cartesian product among parameters""" + prs = [pr.linearize() for pr in param_ranges] + return [ParamConfig(pr) for pr in itertools.product(*prs)] + + @staticmethod + def compose_zip(param_ranges: List[ParamRange]): + l = len(param_ranges[0]) + assert all( + len(pr) == l for pr in param_ranges + ), "All param_range must be the same length" + prs = [pr.linearize() for pr in param_ranges] + return [ParamConfig(prz) for prz in zip(*prs)] + + def generate_configs(self, base_cfg: Dict): + """ + Generate configs from the parameter search plan, also rename the experiment by generating the correct alias. + """ + if len(self.param_configs) > 0: + return [pc.generate_config(base_cfg) for pc in self.param_configs] + else: + # constant-only + const_cfg = ParamConfig(self.const_params) + return [const_cfg.generate_config(base_cfg)] + + +def create_configs( + configs_to_search_fn, + config_name, + config_file, + config_dir, + prefix, + delete_config_dir=True, +): + if config_name is not None: + cfg = get_registered_experiment_config(config_name) + print("Generating configs for {}".format(config_name)) + elif config_file is not None: + # Update default config with external json file + ext_cfg = json.load(open(config_file, "r")) + cfg = get_registered_experiment_config(ext_cfg["registered_name"]) + cfg.update(**ext_cfg) + print("Generating configs with {} as template".format(config_file)) + else: + raise FileNotFoundError("No base config is provided") + + configs = configs_to_search_fn(base_cfg=cfg) + for c in configs: + pfx = "{}_".format(prefix) if prefix is not None else "" + c.name = pfx + c.name + config_fns = [] + + if delete_config_dir and os.path.exists(config_dir): + shutil.rmtree(config_dir) + os.makedirs(config_dir, exist_ok=True) + for c in configs: + fn = os.path.join(config_dir, "{}.json".format(c.name)) + config_fns.append(fn) + print("Saving config to {}".format(fn)) + c.dump(fn) + + return configs, config_fns + + +def read_configs(config_dir): + configs = [] + config_fns = [] + for cfn in glob(config_dir + "/*.json"): + print(cfn) + config_fns.append(cfn) + ext_cfg = json.load(open(cfn, "r")) + c = get_registered_experiment_config(ext_cfg["registered_name"]) + c.update(**ext_cfg) + configs.append(c) + return configs, config_fns + + +def create_evaluation_configs( + configs_to_search_fn, + config_dir, + cfg, + prefix=None, + delete_config_dir=True, +): + configs = configs_to_search_fn(base_cfg=cfg) + for c in configs: + if prefix is not None: + c.name = prefix + "_" + c.name + + config_fns = [] + + if delete_config_dir and os.path.exists(config_dir): + shutil.rmtree(config_dir) + os.makedirs(config_dir, exist_ok=True) + for c in configs: + fn = os.path.join(config_dir, "{}.json".format(c.name)) + config_fns.append(fn) + print("Saving config to {}".format(fn)) + c.dump(fn) + + return configs, config_fns + + +def read_evaluation_configs(config_dir): + configs = [] + config_fns = [] + for cfn in glob(config_dir + "/*.json"): + print(cfn) + config_fns.append(cfn) + c = EvaluationConfig() + ext_cfg = json.load(open(cfn, "r")) + c.update(**ext_cfg) + configs.append(c) + return configs, config_fns + + +def upload_codebase_to_ngc_workspace(ngc_config): + """ + Upload local codebase to NGC workspace + Args: + ngc_config (dict): NGC config + + """ + ngc_path = os.path.join(ngc_config["workspace_mounting_point_local"], "diffstack/") + local_path = Path(diffstack.__path__[0]).parent + assert os.path.exists(ngc_path), "please mount NGC path first" + dir_list = ["scripts/", "diffstack/"] + for d in dir_list: + print("uploading {}".format(d)) + shutil.copytree( + os.path.join(local_path, d), os.path.join(ngc_path, d), dirs_exist_ok=True + ) + file_list = ["setup.py"] + for f in file_list: + print("uploading {}".format(f)) + shutil.copy(os.path.join(local_path, f), os.path.join(ngc_path, f)) + + +def launch_experiments_local(script_path, cfgs, cfg_paths, extra_args=[]): + for cfg, cpath in zip(cfgs, cfg_paths): + cmd = ["python", script_path, "--config_file", cpath] + extra_args + subprocess.run(cmd) + + +def get_results_info_ngc(ngc_job_id): + cmd = ["ngc", "result", "info", str(ngc_job_id), "--files"] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + outs, errs = process.communicate() + if len(errs) > 0: + print(str(errs)) + return None + outs = str(outs).split("\\n") + ckpt_paths = [l.strip(" ") for l in outs if l.endswith(".ckpt")] + cfg_path = [l.strip(" ") for l in outs if l.endswith(".json")] + assert len(cfg_path) == 1 + cfg_path = cfg_path[0] + job_name = cfg_path.split("/")[1] + return ckpt_paths, cfg_path, job_name + + +def _download_from_ngc(ngc_job_id, paths_to_download, target_dir, tmp_dir="/tmp"): + cmd = ["ngc", "result", "download", str(ngc_job_id)] + print("Downloading: ") + for fp in paths_to_download: + print(fp) + cmd.extend(["--file", fp]) + + cmd.extend(["--dest", tmp_dir]) + if os.path.exists(os.path.join(tmp_dir, ngc_job_id + "/")): + print("tmp folder with ngc job ID exists, removing ...") + shutil.rmtree( + os.path.join(tmp_dir, ngc_job_id + "/") + ) # otherwise ngc renames the downloaded folder + subprocess.run(cmd) + + os.makedirs(target_dir, exist_ok=True) + + for fp in paths_to_download: + src_path = os.path.join(tmp_dir, ngc_job_id + fp) + shutil.move(src_path, target_dir) + + shutil.rmtree(os.path.join(tmp_dir, ngc_job_id + "/")) + + +def download_checkpoints_from_ngc( + ngc_job_id, ckpt_root_dir, ckpt_path_func=None, tmp_dir="/tmp" +): + assert os.path.exists(ckpt_root_dir) + ckpt_paths, cfg_path, job_name = get_results_info_ngc(ngc_job_id) + + if ckpt_path_func is None: + + def ckpt_path_func(x): + return x + + to_download = ckpt_path_func(ckpt_paths) + to_download.append(cfg_path) + ckpt_target_dir = os.path.join(ckpt_root_dir, "{}_{}".format(job_name, ngc_job_id)) + + _download_from_ngc(ngc_job_id, to_download, ckpt_target_dir, tmp_dir=tmp_dir) + return ckpt_target_dir + + +def get_local_ngc_checkpoint_dir(ngc_job_id, ckpt_root_dir): + for p in glob(ckpt_root_dir + "/*"): + if str(ngc_job_id) == p.split("_")[-1] or str(ngc_job_id) == p.split("/")[-1]: + return p + return None + + +def get_checkpoint( + ckpt_key, + ngc_job_id=None, + ckpt_dir=None, + ckpt_root_dir="checkpoints/", + download_tmp_dir="/tmp", +): + """ + Get checkpoint and config path given either a ngc job ID or a local dir. + + If a @ngc_job_id is specified, the function will first look for a directory that ends with @ngc_job_id inside + @ckpt_root_dir. E.g., if @ngc_job_id == `1234567`, and @ckpt_root_dir == "checkpoints/", the function will look for + a directory "checkpoints/*_1234567", and within the directory a `.ckpt` file that contains @ckpt_key. + If no such directory or checkpoint file exists, it will try to download the checkpoint from + NGC under the result directory of a job and save it locally such that it will not need to download it again the + next time that this function is invoked. + + If a @ckpt_dir is specified, the function will look for the directory locally and return the ckpt that contains + @ckpt_key, as well as its config.json. + + Args: + ckpt_key (str): a string that uniquely identifies a checkpoint file with a directory, e.g., `iter50000.ckpt` + ngc_job_id (str): (Optional) ngc job ID of the checkpoint if the training was done on NGC. + ckpt_dir (str): (Optional) a local directory that contains the specified checkpoint + ckpt_root_dir (str): (Optional) a directory that the function will look for checkpoints downloaded from NGC + download_tmp_dir (str): a temporary storage for the checkpoint. + + Returns: + ckpt_path (str): path to a checkpoint file + cfg_path (str): path to a config.json file + """ + + def ckpt_path_func(paths): + return [p for p in paths if str(ckpt_key) in p] + + if ngc_job_id is None or len(str(ngc_job_id)) == 0: + local_dir = ckpt_dir + assert ckpt_dir is not None + else: + local_dir = get_local_ngc_checkpoint_dir(ngc_job_id, ckpt_root_dir) + + if local_dir is None: + print("checkpoint does not exist, downloading ...") + ckpt_dir = download_checkpoints_from_ngc( + ngc_job_id=str(ngc_job_id), + ckpt_root_dir=ckpt_root_dir, + ckpt_path_func=ckpt_path_func, + tmp_dir=download_tmp_dir, + ) + else: + ckpt_paths = glob(local_dir + "/**/*.ckpt", recursive=True) + if len(ckpt_path_func(ckpt_paths)) == 0: + if ngc_job_id is not None: + print("checkpoint does not exist, downloading ...") + ckpt_dir = download_checkpoints_from_ngc( + ngc_job_id=str(ngc_job_id), + ckpt_root_dir=ckpt_root_dir, + ckpt_path_func=ckpt_path_func, + tmp_dir=download_tmp_dir, + ) + else: + raise FileNotFoundError( + "Cannot find checkpoint in {} with key {}".format( + local_dir, ckpt_key + ) + ) + else: + ckpt_dir = local_dir + + ckpt_paths = ckpt_path_func(glob(ckpt_dir + "/**/*.ckpt", recursive=True)) + assert len(ckpt_paths) > 0, "Could not find a checkpoint that has key {}".format( + ckpt_key + ) + assert len(ckpt_paths) == 1, "More than one checkpoint found {}".format(ckpt_paths) + cfg_path = glob(ckpt_dir + "/**/config.json", recursive=True)[0] + print("Checkpoint path: {}".format(ckpt_paths[0])) + print("Config path: {}".format(cfg_path)) + return ckpt_paths[0], cfg_path + + +if __name__ == "__main__": + # print(get_checkpoint("2546043", ckpt_key="iter87999_")) + pass diff --git a/diffstack/utils/fp16_util.py b/diffstack/utils/fp16_util.py new file mode 100644 index 0000000..23e0418 --- /dev/null +++ b/diffstack/utils/fp16_util.py @@ -0,0 +1,76 @@ +""" +Helpers to train with 16-bit precision. +""" + +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + l.bias.data = l.bias.data.float() + + +def make_master_params(model_params): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + model_params = list(model_params) + + for param, master_param in zip( + model_params, unflatten_master_params(model_params, master_params) + ): + param.detach().copy_(master_param) + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), model_params) + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() diff --git a/diffstack/utils/geometry_utils.py b/diffstack/utils/geometry_utils.py new file mode 100644 index 0000000..d2b980d --- /dev/null +++ b/diffstack/utils/geometry_utils.py @@ -0,0 +1,567 @@ +import numpy as np + +import torch +from diffstack.utils.tensor_utils import round_2pi +from enum import IntEnum + +# Expose some functions from trajdata +from trajdata.utils.arr_utils import ( + transform_matrices, + batch_nd_transform_points_pt, + batch_nd_transform_points_np, +) + + +def get_box_agent_coords(pos, yaw, extent): + corners = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5).to( + pos.device + ) * (extent.unsqueeze(-2)) + s = torch.sin(yaw).unsqueeze(-1) + c = torch.cos(yaw).unsqueeze(-1) + rotM = torch.cat((torch.cat((c, s), dim=-1), torch.cat((-s, c), dim=-1)), dim=-2) + rotated_corners = (corners + pos.unsqueeze(-2)) @ rotM + return rotated_corners + + +def get_box_world_coords(pos, yaw, extent): + corners = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5).to( + pos.device + ) * (extent.unsqueeze(-2)) + s = torch.sin(yaw).unsqueeze(-1) + c = torch.cos(yaw).unsqueeze(-1) + rotM = torch.cat((torch.cat((c, s), dim=-1), torch.cat((-s, c), dim=-1)), dim=-2) + rotated_corners = corners @ rotM + pos.unsqueeze(-2) + return rotated_corners + + +def get_box_agent_coords_np(pos, yaw, extent): + corners = (np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5) * ( + extent[..., None, :] + ) + s = np.sin(yaw)[..., None] + c = np.cos(yaw)[..., None] + rotM = np.concatenate( + (np.concatenate((c, s), axis=-1), np.concatenate((-s, c), axis=-1)), axis=-2 + ) + rotated_corners = (corners + pos[..., None, :]) @ rotM + return rotated_corners + + +def extents_to_corners(extents_lw: torch.Tensor, xyh: torch.Tensor) -> torch.Tensor: + """ + Args: + extents_lw: (..., 2) (length, width) + xyh: (..., 3) + Returns: + box_points_xy (..., 4, 2) + + """ + assert extents_lw.shape[-1] == 2 and xyh.shape[-1] == 3 + assert extents_lw.ndim == xyh.ndim + extents_lw = extents_lw.float() + + rel_points = torch.tensor( + [[0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5], [0.5, -0.5]], + dtype=torch.float, + device=extents_lw.device, + ) + rel_points = extents_lw.unsqueeze(-2) * rel_points.view( + ([1] * (extents_lw.ndim - 1)) + [4, 2] + ) + + xy, h = torch.split(xyh, (2, 1), dim=-1) + tf = transform_matrices(h.squeeze(-1).double(), xy.double()) + box_points_xy = batch_nd_transform_points_pt(rel_points.double(), tf) + return box_points_xy.type_as(xyh) + + +def get_box_world_coords_np(pos, yaw, extent): + corners = (np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5) * ( + extent[..., None, :] + ) + s = np.sin(yaw)[..., None] + c = np.cos(yaw)[..., None] + rotM = np.concatenate( + (np.concatenate((c, s), axis=-1), np.concatenate((-s, c), axis=-1)), axis=-2 + ) + rotated_corners = corners @ rotM + pos[..., None, :] + return rotated_corners + + +def get_upright_box(pos, extent): + yaws = torch.zeros(*pos.shape[:-1], 1).to(pos.device) + boxes = get_box_world_coords(pos, yaws, extent) + upright_boxes = boxes[..., [0, 2], :] + return upright_boxes + + +def batch_nd_transform_points(points, Mat): + ndim = Mat.shape[-1] - 1 + Mat = Mat.transpose(-1, -2) + return (points.unsqueeze(-2) @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + + +def transform_points_tensor( + points: torch.Tensor, transf_matrix: torch.Tensor +) -> torch.Tensor: + """ + Transform a set of 2D/3D points using the given transformation matrix. + Assumes row major ordering of the input points. The transform function has 3 modes: + - points (N, F), transf_matrix (F+1, F+1) + all points are transformed using the matrix and the output points have shape (N, F). + - points (B, N, F), transf_matrix (F+1, F+1) + all sequences of points are transformed using the same matrix and the output points have shape (B, N, F). + transf_matrix is broadcasted. + - points (B, N, F), transf_matrix (B, F+1, F+1) + each sequence of points is transformed using its own matrix and the output points have shape (B, N, F). + Note this function assumes points.shape[-1] == matrix.shape[-1] - 1, which means that last + rows in the matrices do not influence the final results. + For 2D points only the first 2x3 parts of the matrices will be used. + + :param points: Input points of shape (N, F) or (B, N, F) + with F = 2 or 3 depending on input points are 2D or 3D points. + :param transf_matrix: Transformation matrix of shape (F+1, F+1) or (B, F+1, F+1) with F = 2 or 3. + :return: Transformed points of shape (N, F) or (B, N, F) depending on the dimensions of the input points. + """ + points_log = f" received points with shape {points.shape} " + matrix_log = f" received matrices with shape {transf_matrix.shape} " + + assert points.ndim in [2, 3], f"points should have ndim in [2,3],{points_log}" + assert transf_matrix.ndim in [ + 2, + 3, + ], f"matrix should have ndim in [2,3],{matrix_log}" + assert ( + points.ndim >= transf_matrix.ndim + ), f"points ndim should be >= than matrix,{points_log},{matrix_log}" + + points_feat = points.shape[-1] + assert points_feat in [2, 3], f"last points dimension must be 2 or 3,{points_log}" + assert ( + transf_matrix.shape[-1] == transf_matrix.shape[-2] + ), f"matrix should be a square matrix,{matrix_log}" + + matrix_feat = transf_matrix.shape[-1] + assert matrix_feat in [3, 4], f"last matrix dimension must be 3 or 4,{matrix_log}" + assert ( + points_feat == matrix_feat - 1 + ), f"points last dim should be one less than matrix,{points_log},{matrix_log}" + + def _transform(points: torch.Tensor, transf_matrix: torch.Tensor) -> torch.Tensor: + num_dims = transf_matrix.shape[-1] - 1 + transf_matrix = torch.permute(transf_matrix, (0, 2, 1)) + return ( + points @ transf_matrix[:, :num_dims, :num_dims] + + transf_matrix[:, -1:, :num_dims] + ) + + if points.ndim == transf_matrix.ndim == 2: + points = torch.unsqueeze(points, 0) + transf_matrix = torch.unsqueeze(transf_matrix, 0) + return _transform(points, transf_matrix)[0] + + elif points.ndim == transf_matrix.ndim == 3: + return _transform(points, transf_matrix) + + elif points.ndim == 3 and transf_matrix.ndim == 2: + transf_matrix = torch.unsqueeze(transf_matrix, 0) + return _transform(points, transf_matrix) + else: + raise NotImplementedError(f"unsupported case!{points_log},{matrix_log}") + + +def PED_PED_collision(p1, p2, S1, S2): + if isinstance(p1, torch.Tensor): + return ( + torch.linalg.norm(p1[..., 0:2] - p2[..., 0:2], dim=-1) + - (S1[..., 0] + S2[..., 0]) / 2 + ) + + elif isinstance(p1, np.ndarray): + return ( + np.linalg.norm(p1[..., 0:2] - p2[..., 0:2], axis=-1) + - (S1[..., 0] + S2[..., 0]) / 2 + ) + else: + raise NotImplementedError + + +def transform_xy_coordinate_frame(traj_xy, anchor_xyh): + if isinstance(traj_xy, np.ndarray): + delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1) + x, y = np.split(traj_xy, 2, axis=-1) + # Move poses to origin + x = x - delta_x + y = y - delta_y + # Rotate around new origin with -delta_h + traj_xy = batch_rotate_2D(np.concatenate((x, y), -1), -delta_h) + elif isinstance(traj_xy, torch.tensor): + delta_x, delta_y, delta_h = torch.split(anchor_xyh, 3, dim=-1) + x, y = torch.split(traj_xy, 2, dim=-1) + # Move poses to origin + x = x - delta_x + y = y - delta_y + # Rotate around new origin with -delta_h + traj_xy = batch_rotate_2D(torch.cat((x, y), -1), -delta_h) + return traj_xy + + +def transform_xyh_coordinate_frame(traj_xyh, anchor_xyh): + if isinstance(traj_xyh, np.ndarray): + delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1) + xy, h = np.split(traj_xyh, (2,), axis=-1) + xy = transform_xy_coordinate_frame(xy, anchor_xyh) + h = h - delta_h + traj_xyh = np.concatenate((xy, h), axis=-1) + elif isinstance(traj_xyh, torch.tensor): + delta_x, delta_y, delta_h = torch.split(anchor_xyh, 3, axis=-1) + xy, h = torch.split(traj_xyh, (2, 1), dim=-1) + xy = transform_xy_coordinate_frame(xy, anchor_xyh) + h = h - delta_h + traj_xyh = torch.cat((xy, h), dim=-1) + return traj_xyh + + +def transform_xyhvv_coordinate_frame(traj_xyhvv, anchor_xyh): + if isinstance(traj_xyhvv, np.ndarray): + delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1) + xy, h, vx, vy = np.split( + traj_xyhvv, + ( + 2, + 3, + 4, + ), + axis=-1, + ) + xy = transform_xy_coordinate_frame(xy, anchor_xyh) + h = h - delta_h + vxy = batch_rotate_2D(np.concatenate((vx, vy), -1), -delta_h) + traj_xyhvv = np.concatenate((xy, h, vxy), axis=-1) + elif isinstance(traj_xyhvv, torch.tensor): + delta_x, delta_y, delta_h = torch.split(anchor_xyh, 3, axis=-1) + xy, h, vx, vy = torch.split(traj_xyhvv, (2, 1, 1, 1), dim=-1) + xy = transform_xy_coordinate_frame(xy, anchor_xyh) + h = h - delta_h + vxy = batch_rotate_2D(torch.cat((vx, vy), -1), -delta_h) + traj_xyhvv = torch.cat((xy, h, vxy), dim=-1) + return traj_xyhvv + + +def batch_rotate_2D(xy, theta): + if isinstance(xy, torch.Tensor): + x1 = xy[..., 0] * torch.cos(theta) - xy[..., 1] * torch.sin(theta) + y1 = xy[..., 1] * torch.cos(theta) + xy[..., 0] * torch.sin(theta) + return torch.stack([x1, y1], dim=-1) + elif isinstance(xy, np.ndarray): + x1 = xy[..., 0] * np.cos(theta) - xy[..., 1] * np.sin(theta) + y1 = xy[..., 1] * np.cos(theta) + xy[..., 0] * np.sin(theta) + return np.stack((x1, y1), axis=-1) + + +def VEH_VEH_collision(p1, p2, S1, S2, offsetX=0.0, offsetY=0.0): + if isinstance(p1, torch.Tensor): + cornersX = torch.kron( + S1[..., 0] + offsetX, torch.tensor([0.5, 0.5, -0.5, -0.5], device=p1.device) + ) + cornersY = torch.kron( + S1[..., 1] + offsetY, torch.tensor([0.5, -0.5, 0.5, -0.5], device=p1.device) + ) + corners = torch.stack([cornersX, cornersY], dim=-1) + theta1 = p1[..., 2] + theta2 = p2[..., 2] + dx = (p1[..., 0:2] - p2[..., 0:2]).repeat_interleave(4, dim=-2) + delta_x1 = batch_rotate_2D(corners, theta1.repeat_interleave(4, dim=-1)) + dx + delta_x2 = batch_rotate_2D(delta_x1, -theta2.repeat_interleave(4, dim=-1)) + dis = torch.maximum( + torch.abs(delta_x2[..., 0]) - 0.5 * S2[..., 0].repeat_interleave(4, dim=-1), + torch.abs(delta_x2[..., 1]) - 0.5 * S2[..., 1].repeat_interleave(4, dim=-1), + ).view(*S1.shape[:-1], 4) + min_dis, _ = torch.min(dis, dim=-1) + + return min_dis + + elif isinstance(p1, np.ndarray): + cornersX = np.kron(S1[..., 0] + offsetX, np.array([0.5, 0.5, -0.5, -0.5])) + cornersY = np.kron(S1[..., 1] + offsetY, np.array([0.5, -0.5, 0.5, -0.5])) + corners = np.concatenate((cornersX, cornersY), axis=-1) + theta1 = p1[..., 2] + theta2 = p2[..., 2] + dx = (p1[..., 0:2] - p2[..., 0:2]).repeat(4, axis=-2) + delta_x1 = batch_rotate_2D(corners, theta1.repeat(4, axis=-1)) + dx + delta_x2 = batch_rotate_2D(delta_x1, -theta2.repeat(4, axis=-1)) + dis = np.maximum( + np.abs(delta_x2[..., 0]) - 0.5 * S2[..., 0].repeat(4, axis=-1), + np.abs(delta_x2[..., 1]) - 0.5 * S2[..., 1].repeat(4, axis=-1), + ).reshape(*S1.shape[:-1], 4) + min_dis = np.min(dis, axis=-1) + return min_dis + else: + raise NotImplementedError + + +def VEH_PED_collision(p1, p2, S1, S2): + if isinstance(p1, torch.Tensor): + mask = torch.logical_or( + torch.abs(p1[..., 2]) > 0.1, torch.linalg.norm(p2[..., 2:4], dim=-1) > 0.1 + ).detach() + theta = p1[..., 2] + dx = batch_rotate_2D(p2[..., 0:2] - p1[..., 0:2], -theta) + + return torch.maximum( + torch.abs(dx[..., 0]) - S1[..., 0] / 2 - S2[..., 0] / 2, + torch.abs(dx[..., 1]) - S1[..., 1] / 2 - S2[..., 0] / 2, + ) + elif isinstance(p1, np.ndarray): + theta = p1[..., 2] + dx = batch_rotate_2D(p2[..., 0:2] - p1[..., 0:2], -theta) + return np.maximum( + np.abs(dx[..., 0]) - S1[..., 0] / 2 - S2[..., 0] / 2, + np.abs(dx[..., 1]) - S1[..., 1] / 2 - S2[..., 0] / 2, + ) + else: + raise NotImplementedError + + +def PED_VEH_collision(p1, p2, S1, S2): + return VEH_PED_collision(p2, p1, S2, S1) + + +def batch_proj(x, line): + # x:[batch,3], line:[batch,N,3] + line_length = line.shape[-2] + batch_dim = x.ndim - 1 + if isinstance(x, torch.Tensor): + delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( + *([1] * batch_dim), line_length, 1 + ) + dis = torch.linalg.norm(delta, axis=-1) + idx0 = torch.argmin(dis, dim=-1) + idx = idx0.view(*line.shape[:-2], 1, 1).repeat( + *([1] * (batch_dim + 1)), line.shape[-1] + ) + line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( + line_min[..., None, 2] + ) + delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( + line_min[..., None, 2] + ) + # ref_pts = torch.stack( + # [ + # line_min[..., 0] + delta_x * torch.cos(line_min[..., 2]), + # line_min[..., 1] + delta_x * torch.sin(line_min[..., 2]), + # line_min[..., 2], + # ], + # dim=-1, + # ) + delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) + + return ( + delta_x, + delta_y, + torch.unsqueeze(delta_psi, dim=-1), + ) + elif isinstance(x, np.ndarray): + delta = line[..., 0:2] - np.repeat( + x[..., np.newaxis, 0:2], line_length, axis=-2 + ) + dis = np.linalg.norm(delta, axis=-1) + idx0 = np.argmin(dis, axis=-1) + idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) + line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( + line_min[..., None, 2] + ) + delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( + line_min[..., None, 2] + ) + # line_min[..., 0] += delta_x * np.cos(line_min[..., 2]) + # line_min[..., 1] += delta_x * np.sin(line_min[..., 2]) + delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) + return ( + delta_x, + delta_y, + np.expand_dims(delta_psi, axis=-1), + ) + + +def batch_proj_xysc(x, line): + # x:[batch,4](x,y,s,c), line:[batch,N,4](x,y,s,c) + # normalizing s and c + x = normalize_xysc(x) + line = normalize_xysc(line) + + line_length = line.shape[-2] + batch_dim = x.ndim - 1 + if isinstance(x, torch.Tensor): + x = normalize_xysc(x) + line = normalize_xysc(line) + delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( + *([1] * batch_dim), line_length, 1 + ) + dis = torch.linalg.norm(delta, axis=-1) + idx0 = torch.argmin(dis, dim=-1) + idx = idx0.view(*line.shape[:-2], 1, 1).repeat( + *([1] * (batch_dim + 1)), line.shape[-1] + ) + line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + + delta_x = dx * line_min[..., None, 3] + dy * line_min[..., None, 2] + delta_y = -dx * line_min[..., None, 2] + dy * line_min[..., None, 3] + + delta_s = x[..., None, 2] * line[..., 3] - x[..., None, 3] * line[..., 2] + delta_c = x[..., None, 2] * line[..., 2] + x[..., None, 3] * line[..., 3] + + return torch.stack((delta_x, delta_y, delta_s, delta_c), -1) + else: + raise NotImplementedError + + +def normalize_sc(x): + x_zero_flag = (x == 0).all(-1).type(x.dtype) + x = x + x_zero_flag[..., None] * torch.stack( + [torch.zeros_like(x_zero_flag), torch.ones_like(x_zero_flag)], -1 + ) + x = x / torch.norm(x, dim=-1, keepdim=True) + return x + + +def normalize_xysc(x): + xsc = normalize_sc(x[..., 2:4]) + return torch.cat([x[..., :2], xsc], -1) + + +class CollisionType(IntEnum): + """This enum defines the three types of collisions: front, rear and side.""" + + FRONT = 0 + REAR = 1 + SIDE = 2 + + +def detect_collision( + ego_pos: np.ndarray, + ego_yaw: np.ndarray, + ego_extent: np.ndarray, + other_pos: np.ndarray, + other_yaw: np.ndarray, + other_extent: np.ndarray, +): + """ + Computes whether a collision occured between ego and any another agent. + Also computes the type of collision: rear, front, or side. + For this, we compute the intersection of ego's four sides with a target + agent and measure the length of this intersection. A collision + is classified into a class, if the corresponding length is maximal, + i.e. a front collision exhibits the longest intersection with + egos front edge. + + .. note:: please note that this funciton will stop upon finding the first + colision, so it won't return all collisions but only the first + one found. + + :param ego_pos: predicted centroid + :param ego_yaw: predicted yaw + :param ego_extent: predicted extent + :param other_pos: target agents + :return: None if not collision was found, and a tuple with the + collision type and the agent track_id + """ + from l5kit.planning import utils + + ego_bbox = utils._get_bounding_box(centroid=ego_pos, yaw=ego_yaw, extent=ego_extent) + + # within_range_mask = utils.within_range(ego_pos, ego_extent, other_pos, other_extent) + for i in range(other_pos.shape[0]): + agent_bbox = utils._get_bounding_box( + other_pos[i], other_yaw[i], other_extent[i] + ) + if ego_bbox.intersects(agent_bbox): + front_side, rear_side, left_side, right_side = utils._get_sides(ego_bbox) + + intersection_length_per_side = np.asarray( + [ + agent_bbox.intersection(front_side).length, + agent_bbox.intersection(rear_side).length, + agent_bbox.intersection(left_side).length, + agent_bbox.intersection(right_side).length, + ] + ) + argmax_side = np.argmax(intersection_length_per_side) + + # Remap here is needed because there are two sides that are + # mapped to the same collision type CollisionType.SIDE + max_collision_types = max(CollisionType).value + remap_argmax = min(argmax_side, max_collision_types) + collision_type = CollisionType(remap_argmax) + return collision_type, i + return None + + +def calc_distance_map(road_flag, max_dis=10, mode="L1"): + """mark the image with manhattan distance to the drivable area + + Args: + road_flag (torch.Tensor[B,W,H]): an image with 1 channel, 1 for drivable area, 0 for non-drivable area + max_dis (int, optional): maximum distance that the result saturates to. Defaults to 10. + """ + out = torch.zeros_like(road_flag, dtype=torch.float) + out[road_flag == 0] = max_dis + out[road_flag == 1] = 0 + if mode == "L1": + for i in range(max_dis - 1): + out[..., 1:, :] = torch.min(out[..., 1:, :], out[..., :-1, :] + 1) + out[..., :-1, :] = torch.min(out[..., :-1, :], out[..., 1:, :] + 1) + out[..., :, 1:] = torch.min(out[..., :, 1:], out[..., :, :-1] + 1) + out[..., :, :-1] = torch.min(out[..., :, :-1], out[..., :, 1:] + 1) + elif mode == "Linf": + for i in range(max_dis - 1): + out[..., 1:, :] = torch.min(out[..., 1:, :], out[..., :-1, :] + 1) + out[..., :-1, :] = torch.min(out[..., :-1, :], out[..., 1:, :] + 1) + out[..., :, 1:] = torch.min(out[..., :, 1:], out[..., :, :-1] + 1) + out[..., :, :-1] = torch.min(out[..., :, :-1], out[..., :, 1:] + 1) + out[..., 1:, 1:] = torch.min(out[..., 1:, 1:], out[..., :-1, :-1] + 1) + out[..., 1:, :-1] = torch.min(out[..., 1:, :-1], out[..., :-1, 1:] + 1) + out[..., :-1, :-1] = torch.min(out[..., :-1, :-1], out[..., 1:, 1:] + 1) + out[..., :-1, 1:] = torch.min(out[..., :-1, 1:], out[..., 1:, :-1] + 1) + + return out + + +def rel_xysc(p1, p2): + """calculate the relative position of p2 to p1 + + Args: + p1 (torch.Tensor[B,4]): [x,y,s,c] + p2 (torch.Tensor[B,4]): [x,y,s,c] + + Returns: + torch.Tensor[B,4]: [dx,dy,ds,dc] + """ + # normalizing s and c + p1 = normalize_xysc(p1) + p2 = normalize_xysc(p2) + + dx = p2[..., 0] - p1[..., 0] + dy = p2[..., 1] - p1[..., 1] + dxl = dx * p1[..., 3] + dy * p1[..., 2] + dyl = -dx * p1[..., 2] + dy * p1[..., 3] + ds = p2[..., 2] * p1[..., 3] - p2[..., 3] * p1[..., 2] + dc = p2[..., 3] * p1[..., 3] + p2[..., 2] * p1[..., 2] + return torch.stack((dxl, dyl, ds, dc), -1) + + +def ratan2(s, c, eps=1e-4): + # robust arctan2 for pytorch + sign = (c >= 0).float() * 2 - 1 + eps = eps * (c.abs() < eps).type(c.dtype) * sign + return torch.arctan2(s, c + eps) diff --git a/diffstack/utils/homotopy.py b/diffstack/utils/homotopy.py new file mode 100644 index 0000000..6f3b665 --- /dev/null +++ b/diffstack/utils/homotopy.py @@ -0,0 +1,113 @@ +from Pplan.Sampling.spline_planner import SplinePlanner +from Pplan.Sampling.trajectory_tree import TrajTree + +import torch +import numpy as np +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.geometry_utils import ratan2 +from diffstack.utils.tree import Tree +from typing import List +from enum import IntEnum + + +HOMOTOPY_THRESHOLD = np.pi / 6 + + +class HomotopyType(IntEnum): + """ + Homotopy class between two paths + STATIC: relatively small wraping angle + CW: clockwise + CCW: counter-clockwise + """ + + STATIC = 0 + CW = 1 + CCW = 2 + + @staticmethod + def enforce_symmetry(x, mode="U"): + assert x.shape[-2] == x.shape[-1] + xT = x.transpose(-1, -2).clone() + + diag_mask = torch.eye(x.shape[-1], device=x.device).bool().expand(*x.shape) + x.masked_fill_(diag_mask, HomotopyType.STATIC) + + if mode == "U": + # force symmetry based on upper triangular matrix + triangle = torch.tril(torch.ones_like(x), diagonal=-1) + x = x * (1 - triangle) + xT * triangle + elif mode == "L": + # force symmetry based on lower triangular matrix + triangle = torch.triu(torch.ones_like(x), diagonal=1) + x = x * (1 - triangle) + xT * triangle + return x + + +def mag_integral(path0, path1, mask=None): + if isinstance(path0, torch.Tensor): + delta_path = path0 - path1 + close_flag = torch.norm(delta_path, dim=-1) < 1e-3 + angle = ratan2(delta_path[..., 1], delta_path[..., 0]).masked_fill( + close_flag, 0 + ) + delta_angle = GeoUtils.round_2pi(angle[..., 1:] - angle[..., :-1]) + if mask is not None: + if mask.ndim == delta_angle.ndim - 1: + delta_angle = delta_angle * mask[..., None] + elif mask.ndim == delta_angle.ndim: + diff_mask = mask[..., 1:] * mask[..., :-1] + delta_angle = delta_angle * diff_mask + angle_diff = torch.sum(delta_angle, dim=-1) + + elif isinstance(path0, torch.ndarray): + delta_path = path0 - path1 + close_flag = (np.norm(delta_path, dim=-1) < 1e-3).float() + angle = np.arctan2(delta_path[..., 1], delta_path[..., 0]) * (1 - close_flag) + delta_angle = GeoUtils.round_2pi(angle[..., 1:] - angle[..., :-1]) + if mask is not None: + if mask.ndim == delta_angle.ndim - 1: + delta_angle = delta_angle * mask[..., None] + elif mask.ndim == delta_angle.ndim: + diff_mask = mask[..., 1:] * mask[..., :-1] + delta_angle = delta_angle * diff_mask + angle_diff = np.sum(delta_angle, axis=-1) + return angle_diff + + +def identify_homotopy( + ego_path: torch.Tensor, obj_paths: torch.Tensor, threshold=HOMOTOPY_THRESHOLD +): + """Identifying homotopy classes for the ego + + Args: + ego_path (torch.Tensor): B x T x 2 + obj_paths (torch.Tensor): B x M x N x T x 2 + """ + b, M, N = obj_paths.shape[:3] + angle_diff = mag_integral(ego_path[:, None, None], obj_paths) + homotopy = torch.zeros([b, M, N], device=ego_path.device) + homotopy[angle_diff >= threshold] = HomotopyType.CCW + homotopy[angle_diff <= -threshold] = HomotopyType.CW + homotopy[(angle_diff > -threshold) & (angle_diff < threshold)] = HomotopyType.STATIC + + return angle_diff, homotopy + + +def identify_pairwise_homotopy( + path: torch.Tensor, threshold=HOMOTOPY_THRESHOLD, mask=None +): + """ + Args: + path (torch.Tensor): B x N x T x 2 + """ + b, N, T = path.shape[:3] + if mask is not None: + mask = mask[:, None] * mask[:, :, None] + angle_diff = mag_integral(path[:, :, None], path[:, None], mask=mask) + homotopy = torch.zeros([b, N, N], device=path.device) + homotopy[angle_diff >= threshold] = HomotopyType.CCW + homotopy[angle_diff <= -threshold] = HomotopyType.CW + homotopy[(angle_diff > -threshold) & (angle_diff < threshold)] = HomotopyType.STATIC + + return angle_diff, homotopy diff --git a/diffstack/utils/kalman_filter.py b/diffstack/utils/kalman_filter.py new file mode 100644 index 0000000..85e6866 --- /dev/null +++ b/diffstack/utils/kalman_filter.py @@ -0,0 +1,120 @@ +import numpy as np + +class NonlinearKinematicBicycle: + """ + Nonlinear Kalman Filter for a kinematic bicycle model, assuming constant longitudinal speed + and constant heading array + """ + + def __init__(self, dt, sPos=None, sHeading=None, sVel=None, sMeasurement=None): + self.dt = dt + + # measurement matrix + self.C = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + + # default noise covariance + if (sPos is None) and (sHeading is None) and (sVel is None): + # TODO need to further check + # sPos = 0.5 * 8.8 * dt ** 2 # assume 8.8m/s2 as maximum acceleration + # sHeading = 0.5 * dt # assume 0.5rad/s as maximum turn rate + # sVel = 8.8 * dt # assume 8.8m/s2 as maximum acceleration + # sMeasurement = 1.0 + sPos = 12 * self.dt # assume 6m/s2 as maximum acceleration + sHeading = 0.5 * self.dt # assume 0.5rad/s as maximum turn rate + sVel = 6 * self.dt # assume 6m/s2 as maximum acceleration + if sMeasurement is None: + sMeasurement = 5.0 + # state transition noise + self.Q = np.diag([sPos ** 2, sPos ** 2, sHeading ** 2, sVel ** 2]) + # measurement noise + self.R = np.diag([sMeasurement ** 2, sMeasurement ** 2, sMeasurement ** 2, sMeasurement ** 2]) + + def predict_and_update(self, x_vec_est, u_vec, P_matrix, z_new): + """ + for background please refer to wikipedia: https://en.wikipedia.org/wiki/Extended_Kalman_filter + :param x_vec_est: + :param u_vec: + :param P_matrix: + :param z_new: + :return: + """ + + ## Prediction Step + # predicted state estimate + x_pred = self._kinematic_bicycle_model_rearCG(x_vec_est, u_vec) + # Compute Jacobian to obtain the state transition matrix + A = self._cal_state_Jacobian(x_vec_est, u_vec) + # predicted error covariance + P_pred = A.dot(P_matrix.dot(A.transpose())) + self.Q + + ## Update Step + # innovation or measurement pre-fit residual + y_telda = z_new - self.C.dot(x_pred) + # innovation covariance + S = self.C.dot(P_pred.dot(self.C.transpose())) + self.R + # near-optimal Kalman gain + K = P_pred.dot(self.C.transpose().dot(np.linalg.inv(S))) + # updated (a posteriori) state estimate + x_vec_est_new = x_pred + K.dot(y_telda) + # updated (a posteriori) estimate covariance + P_matrix_new = np.dot((np.identity(4) - K.dot(self.C)), P_pred) + + return x_vec_est_new, P_matrix_new + + def _kinematic_bicycle_model_rearCG(self, x_old, u): + """ + :param x: vehicle state vector = [x position, y position, heading, velocity] + :param u: control vector = [acceleration, steering array] + :param dt: + :return: + """ + + acc = u[0] + delta = u[1] + + x = x_old[0] + y = x_old[1] + psi = x_old[2] + vel = x_old[3] + + x_new = np.array([[0.], [0.], [0.], [0.]]) + + x_new[0] = x + self.dt * vel * np.cos(psi + delta) + x_new[1] = y + self.dt * vel * np.sin(psi + delta) + x_new[2] = psi + self.dt * delta + #x_new[2] = _heading_angle_correction(x_new[2]) + x_new[3] = vel + self.dt * acc + + return x_new + + def _cal_state_Jacobian(self, x_vec, u_vec): + acc = u_vec[0] + delta = u_vec[1] + + x = x_vec[0] + y = x_vec[1] + psi = x_vec[2] + vel = x_vec[3] + + a13 = -self.dt * vel * np.sin(psi + delta) + a14 = self.dt * np.cos(psi + delta) + a23 = self.dt * vel * np.cos(psi + delta) + a24 = self.dt * np.sin(psi + delta) + a34 = self.dt * delta + + JA = np.array([[1.0, 0.0, a13[0], a14[0]], + [0.0, 1.0, a23[0], a24[0]], + [0.0, 0.0, 1.0, a34[0]], + [0.0, 0.0, 0.0, 1.0]]) + + return JA + + +def _heading_angle_correction(theta): + """ + correct heading array so that it always remains in [-pi, pi] + :param theta: + :return: + """ + theta_corrected = (theta + np.pi) % (2.0 * np.pi) - np.pi + return theta_corrected \ No newline at end of file diff --git a/diffstack/utils/l5_utils.py b/diffstack/utils/l5_utils.py new file mode 100644 index 0000000..5a03401 --- /dev/null +++ b/diffstack/utils/l5_utils.py @@ -0,0 +1,726 @@ +import torch +import torch.nn.functional as F + +import diffstack.dynamics as dynamics +import diffstack.utils.tensor_utils as TensorUtils +from diffstack import dynamics as dynamics +from diffstack.configs.base import ExperimentConfig + + +def get_agent_masks(raw_type): + """ + PERCEPTION_LABELS = [ + "PERCEPTION_LABEL_NOT_SET", + "PERCEPTION_LABEL_UNKNOWN", + "PERCEPTION_LABEL_DONTCARE", + "PERCEPTION_LABEL_CAR", + "PERCEPTION_LABEL_VAN", + "PERCEPTION_LABEL_TRAM", + "PERCEPTION_LABEL_BUS", + "PERCEPTION_LABEL_TRUCK", + "PERCEPTION_LABEL_EMERGENCY_VEHICLE", + "PERCEPTION_LABEL_OTHER_VEHICLE", + "PERCEPTION_LABEL_BICYCLE", + "PERCEPTION_LABEL_MOTORCYCLE", + "PERCEPTION_LABEL_CYCLIST", + "PERCEPTION_LABEL_MOTORCYCLIST", + "PERCEPTION_LABEL_PEDESTRIAN", + "PERCEPTION_LABEL_ANIMAL", + "AVRESEARCH_LABEL_DONTCARE", + ] + """ + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + # veh_mask = veh_mask | ped_mask + # ped_mask = ped_mask * 0 + return veh_mask, ped_mask + + +def get_dynamics_types(veh_mask, ped_mask): + dyn_type = torch.zeros_like(veh_mask) + dyn_type += dynamics.DynType.UNICYCLE * veh_mask + dyn_type += dynamics.DynType.DI * ped_mask + return dyn_type + + +def raw_to_features(batch_raw): + """ map raw src into features of dim 21 """ + raw_type = batch_raw["raw_types"] + pos = batch_raw["hist_pos"] + vel = batch_raw["history_velocities"] + yaw = batch_raw["hist_yaw"] + mask = batch_raw["hist_mask"] + + veh_mask, ped_mask = get_agent_masks(raw_type) + + # all vehicles, cyclists, and motorcyclists + feature_veh = torch.cat((pos, vel, torch.cos(yaw), torch.sin(yaw)), dim=-1) + + # pedestrians and animals + ped_feature = torch.cat( + (pos, vel, vel * torch.sin(yaw), vel * torch.cos(yaw)), dim=-1 + ) + + feature = feature_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + ped_feature * ped_mask.view([*raw_type.shape, 1, 1]) + + type_embedding = F.one_hot(raw_type, 16) + + feature = torch.cat( + (feature, type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1)), + dim=-1, + ) + feature = feature * mask.unsqueeze(-1) + + return feature + + +def raw_to_states(batch_raw): + raw_type = batch_raw["raw_types"] + pos = batch_raw["hist_pos"] + vel = batch_raw["history_velocities"] + yaw = batch_raw["hist_yaw"] + avail_mask = batch_raw["hist_mask"] + + veh_mask, ped_mask = get_agent_masks(raw_type) # [B, (A)] + + # all vehicles, cyclists, and motorcyclists + state_veh = torch.cat((pos, vel, yaw), dim=-1) # [B, (A), T, S] + # pedestrians and animals + state_ped = torch.cat((pos, vel * torch.cos(yaw), vel * torch.sin(yaw)), dim=-1) # [B, (A), T, S] + + state = state_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + state_ped * ped_mask.view([*raw_type.shape, 1, 1]) # [B, (A), T, S] + + # Get the current state of the agents + num = torch.arange(0, avail_mask.shape[-1]).view(1, 1, -1).to(avail_mask.device) + nummask = num * avail_mask + last_idx, _ = torch.max(nummask, dim=2) + curr_state = torch.gather( + state, 2, last_idx[..., None, None].repeat(1, 1, 1, 4) + ) + return state, curr_state + + +def batch_to_raw_ego(data_batch, step_time): + batch_size = data_batch["hist_pos"].shape[0] + raw_type = torch.ones(batch_size).type(torch.int64).to(data_batch["hist_pos"].device) # [B, T] + raw_type = raw_type * 3 # index for type PERCEPTION_LABEL_CAR + + src_pos = torch.flip(data_batch["hist_pos"], dims=[-2]) + src_yaw = torch.flip(data_batch["hist_yaw"], dims=[-2]) + src_mask = torch.flip(data_batch["hist_mask"], dims=[-1]).bool() + + src_vel = dynamics.Unicycle(step_time).calculate_vel(pos=src_pos, yaw=src_yaw, mask=src_mask) + src_vel[:, -1] = data_batch["curr_speed"].unsqueeze(-1) + + raw = { + "hist_pos": src_pos, + "history_velocities": src_vel, + "hist_yaw": src_yaw, + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": data_batch["extents"] + } + + raw = TensorUtils.unsqueeze(raw, dim=1) # Add the agent dimension + return raw + + +def raw2feature(pos, vel, yaw, raw_type, mask, lanes=None, add_noise=False): + "map raw src into features of dim 21+lane dim" + + """ + PERCEPTION_LABELS = [ + "PERCEPTION_LABEL_NOT_SET", + "PERCEPTION_LABEL_UNKNOWN", + "PERCEPTION_LABEL_DONTCARE", + "PERCEPTION_LABEL_CAR", + "PERCEPTION_LABEL_VAN", + "PERCEPTION_LABEL_TRAM", + "PERCEPTION_LABEL_BUS", + "PERCEPTION_LABEL_TRUCK", + "PERCEPTION_LABEL_EMERGENCY_VEHICLE", + "PERCEPTION_LABEL_OTHER_VEHICLE", + "PERCEPTION_LABEL_BICYCLE", + "PERCEPTION_LABEL_MOTORCYCLE", + "PERCEPTION_LABEL_CYCLIST", + "PERCEPTION_LABEL_MOTORCYCLIST", + "PERCEPTION_LABEL_PEDESTRIAN", + "PERCEPTION_LABEL_ANIMAL", + "AVRESEARCH_LABEL_DONTCARE", + ] + """ + dyn_type = torch.zeros_like(raw_type) + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + veh_mask = veh_mask | ped_mask + ped_mask = ped_mask * 0 + dyn_type += dynamics.DynType.UNICYCLE * veh_mask + # all vehicles, cyclists, and motorcyclists + if add_noise: + pos_noise = torch.randn(pos.size(0), 1, 1, 2).to(pos.device) * 0.5 + yaw_noise = torch.randn(pos.size(0), 1, 1, 1).to(pos.device) * 0.1 + if pos.ndim == 5: + pos_noise = pos_noise.unsqueeze(1) + yaw_noise = yaw_noise.unsqueeze(1) + feature_veh = torch.cat( + ( + pos + pos_noise, + vel, + torch.cos(yaw + yaw_noise), + torch.sin(yaw + yaw_noise), + ), + dim=-1, + ) + else: + feature_veh = torch.cat((pos, vel, torch.cos(yaw), torch.sin(yaw)), dim=-1) + + state_veh = torch.cat((pos, vel, yaw), dim=-1) + + # pedestrians and animals + if add_noise: + pos_noise = torch.randn(pos.size(0), 1, 1, 2).to(pos.device) * 0.5 + yaw_noise = torch.randn(pos.size(0), 1, 1, 1).to(pos.device) * 0.1 + if pos.ndim == 5: + pos_noise = pos_noise.unsqueeze(1) + yaw_noise = yaw_noise.unsqueeze(1) + ped_feature = torch.cat( + ( + pos + pos_noise, + vel, + vel * torch.sin(yaw + yaw_noise), + vel * torch.cos(yaw + yaw_noise), + ), + dim=-1, + ) + else: + ped_feature = torch.cat( + (pos, vel, vel * torch.sin(yaw), vel * torch.cos(yaw)), dim=-1 + ) + state_ped = torch.cat((pos, vel * torch.cos(yaw), vel * torch.sin(yaw)), dim=-1) + state = state_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + state_ped * ped_mask.view([*raw_type.shape, 1, 1]) + dyn_type += dynamics.DynType.DI * ped_mask + + feature = feature_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + ped_feature * ped_mask.view([*raw_type.shape, 1, 1]) + + type_embedding = F.one_hot(raw_type, 16) + + if pos.ndim == 4: + if lanes is not None: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1), + lanes[:, :, None, :].repeat(1, 1, feature.size(2), 1), + ), + dim=-1, + ) + else: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1), + ), + dim=-1, + ) + + elif pos.ndim == 5: + if lanes is not None: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, 1, feature.size(-2), 1), + lanes[:, :, None, None, :].repeat( + 1, feature.size(1), 1, feature.size(2), 1 + ), + ), + dim=-1, + ) + else: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, 1, feature.size(-2), 1), + ), + dim=-1, + ) + feature = feature * mask.unsqueeze(-1) + return feature, dyn_type, state + + +def batch_to_vectorized_feature(data_batch, dyn_list, step_time, algo_config): + device = data_batch["hist_pos"].device + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["all_other_agents_hist_pos"], + ), + dim=1, + ) + "history position and yaw need to be flipped so that they go from past to recent" + src_pos = torch.flip(src_pos, dims=[-2]) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["all_other_agents_hist_yaw"], + ), + dim=1, + ) + src_yaw = torch.flip(src_yaw, dims=[-2]) + src_world_yaw = src_yaw + ( + data_batch["yaw"] + .view(-1, 1, 1, 1) + .repeat(1, src_yaw.size(1), src_yaw.size(2), 1) + ).type(torch.float) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["all_other_agents_history_availability"], + ), + dim=1, + ).bool() + + src_mask = torch.flip(src_mask, dims=[-1]) + # estimate velocity + src_vel = dyn_list[dynamics.DynType.UNICYCLE].calculate_vel( + src_pos, src_yaw, src_mask + ) + + src_vel[:, 0, -1] = torch.clip( + data_batch["curr_speed"].unsqueeze(-1), + min=algo_config.vmin, + max=algo_config.vmax, + ) + if algo_config.vectorize_lane: + src_lanes = torch.cat( + ( + data_batch["ego_lanes"].unsqueeze(1), + data_batch["all_other_agents_lanes"], + ), + dim=1, + ).type(torch.float) + src_lanes = torch.cat(( + src_lanes[...,0:2], + torch.cos(src_lanes[...,2:3]), + torch.sin(src_lanes[...,2:3]), + src_lanes[...,-1:], + ),dim=-1) + src_lanes = src_lanes.view(*src_lanes.shape[:2], -1) + else: + src_lanes = None + src, dyn_type, src_state = raw2feature( + src_pos, src_vel, src_yaw, raw_type, src_mask, src_lanes + ) + tgt_mask = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["all_other_agents_future_availability"], + ), + dim=1, + ).bool() + num = torch.arange(0, src_mask.shape[2]).view(1, 1, -1).to(src_mask.device) + nummask = num * src_mask + last_idx, _ = torch.max(nummask, dim=2) + curr_state = torch.gather( + src_state, 2, last_idx[..., None, None].repeat(1, 1, 1, 4) + ) + + tgt_pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["all_other_agents_future_positions"], + ), + dim=1, + ) + tgt_yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["all_other_agents_future_yaws"], + ), + dim=1, + ) + tgt_pos_yaw = torch.cat((tgt_pos, tgt_yaw), dim=-1) + + + # curr_pos_yaw = torch.cat((curr_state[..., 0:2], curr_yaw), dim=-1) + + # tgt = tgt - curr_pos_yaw.repeat(1, 1, tgt.size(2), 1) * tgt_mask.unsqueeze(-1) + + + return ( + src, + dyn_type, + src_state, + src_pos, + src_yaw, + src_world_yaw, + src_vel, + raw_type, + src_mask, + src_lanes, + extents, + tgt_pos_yaw, + tgt_mask, + curr_state, + ) + +def obtain_goal_state(tgt_pos_yaw,tgt_mask): + num = torch.arange(0, tgt_mask.shape[2]).view(1, 1, -1).to(tgt_mask.device) + nummask = num * tgt_mask + last_idx, _ = torch.max(nummask, dim=2, keepdim=True) + last_mask = nummask.ge(last_idx) + + goal_mask = tgt_mask*last_mask + goal_pos_yaw = tgt_pos_yaw*goal_mask.unsqueeze(-1) + return goal_pos_yaw[...,:2], goal_pos_yaw[...,2:], goal_mask + + +def batch_to_raw_all_agents(data_batch, step_time): + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["all_other_agents_hist_pos"], + ), + dim=1, + ) + # history position and yaw need to be flipped so that they go from past to recent + src_pos = torch.flip(src_pos, dims=[-2]) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["all_other_agents_hist_yaw"], + ), + dim=1, + ) + src_yaw = torch.flip(src_yaw, dims=[-2]) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["all_other_agents_history_availability"], + ), + dim=1, + ).bool() + + src_mask = torch.flip(src_mask, dims=[-1]) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + # estimate velocity + src_vel = dynamics.Unicycle(step_time).calculate_vel(src_pos, src_yaw, src_mask) + src_vel[:, 0, -1] = data_batch["curr_speed"].unsqueeze(-1) + + return { + "hist_pos": src_pos, + "hist_yaw": src_yaw, + "curr_speed": src_vel[:, :, -1, 0], + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": extents, + } + + +def batch_to_target_all_agents(data_batch): + pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["all_other_agents_future_positions"], + ), + dim=1, + ) + yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["all_other_agents_future_yaws"], + ), + dim=1, + ) + avails = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["all_other_agents_future_availability"], + ), + dim=1, + ) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + return { + "fut_pos": pos, + "fut_yaw": yaw, + "fut_mask": avails, + "extents": extents + } + + +def generate_edges( + raw_type, + extents, + pos_pred, + yaw_pred, +): + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + + agent_mask = veh_mask | ped_mask + edge_types = ["VV", "VP", "PV", "PP"] + edges = {et: list() for et in edge_types} + for i in range(agent_mask.shape[0]): + agent_idx = torch.where(agent_mask[i] != 0)[0] + edge_idx = torch.combinations(agent_idx, r=2) + VV_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + VP_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + PV_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + PP_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + if pos_pred.ndim == 4: + edges_of_all_types = torch.cat( + ( + pos_pred[i, edge_idx[:, 0], :], + yaw_pred[i, edge_idx[:, 0], :], + pos_pred[i, edge_idx[:, 1], :], + yaw_pred[i, edge_idx[:, 1], :], + extents[i, edge_idx[:, 0]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + extents[i, edge_idx[:, 1]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[VV_idx]) + edges["VP"].append(edges_of_all_types[VP_idx]) + edges["PV"].append(edges_of_all_types[PV_idx]) + edges["PP"].append(edges_of_all_types[PP_idx]) + elif pos_pred.ndim == 5: + + edges_of_all_types = torch.cat( + ( + pos_pred[i, :, edge_idx[:, 0], :], + yaw_pred[i, :, edge_idx[:, 0], :], + pos_pred[i, :, edge_idx[:, 1], :], + yaw_pred[i, :, edge_idx[:, 1], :], + extents[i, None, edge_idx[:, 0], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + extents[i, None, edge_idx[:, 1], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[:, VV_idx]) + edges["VP"].append(edges_of_all_types[:, VP_idx]) + edges["PV"].append(edges_of_all_types[:, PV_idx]) + edges["PP"].append(edges_of_all_types[:, PP_idx]) + if pos_pred.ndim == 4: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=0) + elif pos_pred.ndim == 5: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=1) + return edges + + +def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B,N,T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + veh_mask = (raw_types >= 3) & (raw_types <= 13) + ped_mask = (raw_types == 14) | (raw_types == 15) + + edges = torch.zeros([B,N,A,T,10]).to(ego_trajectories.device) + edges[...,:3] = ego_trajectories.unsqueeze(2).repeat(1,1,A,1,1) + if agent_trajectories.ndim==4: + edges[...,3:6] = agent_trajectories.unsqueeze(1).repeat(1,N,1,1,1) + else: + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,1,2).repeat(1,N,A,T,1) + edges[...,8:] = agent_extents.reshape(B,1,A,1,2).repeat(1,N,1,T,1) + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def gen_EC_edges(ego_trajectories,agent_trajectories,ego_extents, agent_extents, raw_types,mask=None): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,A,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + mask (optional, torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + + B,A = ego_trajectories.shape[:2] + T = ego_trajectories.shape[-2] + + veh_mask = (raw_types >= 3) & (raw_types <= 13) + ped_mask = (raw_types == 14) | (raw_types == 15) + + + if ego_trajectories.ndim==4: + edges = torch.zeros([B,A,T,10]).to(ego_trajectories.device) + edges[...,:3] = ego_trajectories + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,1,T,1) + elif ego_trajectories.ndim==5: + + K = ego_trajectories.shape[2] + edges = torch.zeros([B,A*K,T,10]).to(ego_trajectories.device) + edges[...,:3] = TensorUtils.join_dimensions(ego_trajectories,1,3) + edges[...,3:6] = agent_trajectories.repeat(1,K,1,1) + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A*K,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,K,T,1) + veh_mask = veh_mask.tile(1,K) + ped_mask = ped_mask.tile(1,K) + if mask is not None: + veh_mask = veh_mask*mask + ped_mask = ped_mask*mask + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + + # Use predicted ego position to compute future box edges + + targets_all = batch_to_target_all_agents(data_batch) + if ego_predictions is not None: + targets_all["fut_pos"] [:, 0, :, :] = ego_predictions["positions"] + targets_all["fut_yaw"][:, 0, :, :] = ego_predictions["yaws"] + elif all_predictions is not None: + targets_all["fut_pos"] = all_predictions["positions"] + targets_all["fut_yaw"] = all_predictions["yaws"] + else: + raise ValueError("Please specify either ego prediction or all predictions") + + pred_edges = generate_edges( + raw_type, targets_all["extents"], + pos_pred=targets_all["fut_pos"], + yaw_pred=targets_all["fut_yaw"] + ) + return pred_edges + + +def get_last_available_index(avails): + """ + Args: + avails (torch.Tensor): target availabilities [B, (A), T] + + Returns: + last_indices (torch.Tensor): index of the last available frame + """ + num_frames = avails.shape[-1] + inds = torch.arange(0, num_frames).to(avails.device) # [T] + inds = (avails > 0).float() * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(dim=-1)[1] # [B, (A)] calculate the index of the last availale frame + return last_inds + + +def get_current_states(batch: dict, dyn_type: dynamics.DynType) -> torch.Tensor: + bs = batch["curr_speed"].shape[0] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, 6).to(batch["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] + current_states[:, 3] = batch["curr_speed"].abs() + current_states[:, [4]] = (batch["hist_yaw"][:, 0] - batch["hist_yaw"][:, 1]).abs() + current_states[:, 5] = batch["extent"][:, 0] # [veh_len] + else: + current_states = torch.zeros(bs, 4).to(batch["curr_speed"].device) # [x, y, vel, yaw] + current_states[:, 2] = batch["curr_speed"] + return current_states + + +def get_current_states_all_agents(batch: dict, step_time, dyn_type: dynamics.DynType) -> torch.Tensor: + if batch["hist_pos"].ndim==3: + state_all = batch_to_raw_all_agents(batch, step_time) + else: + state_all = batch + bs, na = state_all["curr_speed"].shape[:2] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, na, 6).to(state_all["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] + current_states[:, :, :2] = state_all["hist_pos"][:, :, 0] + current_states[:, :, 3] = state_all["curr_speed"].abs() + current_states[:, :, [4]] = (state_all["hist_yaw"][:, :, 0] - state_all["hist_yaw"][:, :, 1]).abs() + current_states[:, :, 5] = state_all["extent"][:, :, 0] # [veh_len] + else: + current_states = torch.zeros(bs, na, 4).to(state_all["curr_speed"].device) # [x, y, vel, yaw] + current_states[:, :, :2] = state_all["hist_pos"][:, :, 0] + current_states[:, :, 2] = state_all["curr_speed"] + current_states[:,:,3:] = state_all["hist_yaw"][:,:,0] + return current_states + + +def get_drivable_region_map(rasterized_map): + return rasterized_map[..., -3, :, :] < 1. + + +def get_modality_shapes(cfg: ExperimentConfig): + assert cfg.env.rasterizer.map_type == "py_semantic" + num_channels = (cfg.stack.history_num_frames + 1) * 2 + 3 + h, w = cfg.env.rasterizer.raster_size + return dict(image=(num_channels, h, w)) \ No newline at end of file diff --git a/diffstack/utils/lane_utils.py b/diffstack/utils/lane_utils.py new file mode 100644 index 0000000..a7db498 --- /dev/null +++ b/diffstack/utils/lane_utils.py @@ -0,0 +1,553 @@ +import numpy as np +import torch +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.geometry_utils import ratan2 +import enum +from dataclasses import dataclass +import scipy.interpolate as spint + + +class LaneModeConst: + X_ahead_thresh = 5.0 + X_rear_thresh = 0.0 + Y_near_thresh = 1.8 + Y_far_thresh = 5.0 + psi_thresh = np.pi / 4 + longitudinal_scale = 30 + lateral_scale = 1 + heading_scale = 0.5 + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +def get_edge(lane, dir, W=2.0, num_pts=None): + if dir == "L": + if lane.left_edge is not None: + if num_pts is not None: + lane.left_edge = lane.left_edge.interpolate(num_pts) + xy = lane.left_edge.xy + if lane.left_edge.has_heading: + h = lane.left_edge.h + else: + # check if the points are reversed + edge_angle = np.arctan2(xy[-1, 1] - xy[0, 1], xy[-1, 0] - xy[0, 0]) + center_angle = np.arctan2( + lane.center.xy[-1, 1] - lane.center.xy[0, 1], + lane.center.xy[-1, 0] - lane.center.xy[0, 0], + ) + if np.abs(GeoUtils.round_2pi(edge_angle - center_angle)) > np.pi / 2: + xy = np.flip(xy, 0) + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + else: + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + angle = lane.center.h + np.pi / 2 + offset = np.stack([W * np.cos(angle), W * np.sin(angle)], -1) + xy = lane.center.xy + offset + h = lane.center.h + elif dir == "R": + if lane.right_edge is not None: + if num_pts is not None: + lane.right_edge = lane.right_edge.interpolate(num_pts) + xy = lane.right_edge.xy + if lane.right_edge.has_heading: + h = lane.right_edge.h + else: + # check if the points are reversed + edge_angle = np.arctan2(xy[-1, 1] - xy[0, 1], xy[-1, 0] - xy[0, 0]) + center_angle = np.arctan2( + lane.center.xy[-1, 1] - lane.center.xy[0, 1], + lane.center.xy[-1, 0] - lane.center.xy[0, 0], + ) + if np.abs(GeoUtils.round_2pi(edge_angle - center_angle)) > np.pi / 2: + xy = np.flip(xy, 0) + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + else: + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + angle = lane.center.h - np.pi / 2 + offset = np.stack([W * np.cos(angle), W * np.sin(angle)], -1) + xy = lane.center.xy + offset + h = lane.center.h + elif dir == "C": + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + xy = lane.center.xy + if lane.center.has_heading: + h = lane.center.h + else: + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + return xy, h + + +def get_bdry_xyh(lane1, lane2=None, dir="L", W=3.6, num_pts=25): + if lane2 is None: + xy, h = get_edge(lane1, dir, W, num_pts) + else: + xy1, h1 = get_edge(lane1, dir, W, num_pts) + xy2, h2 = get_edge(lane2, dir, W, num_pts) + xy = np.concatenate((xy1, xy2), 0) + h = np.concatenate((h1, h2), 0) + return xy, h + + +def LaneRelationFromCfg(lane_relation): + if lane_relation == "SimpleLaneRelation": + return SimpleLaneRelation + elif lane_relation == "LRLaneRelation": + return LRLaneRelation + elif lane_relation == "LaneRelation": + return LaneRelation + else: + raise ValueError("Invalid lane relation type") + + +class SimpleLaneRelation(enum.IntEnum): + """ + Categorical token describing the relationship between an agent and a Lane, unitary lane mode that only considers which lane the agent is one + """ + + NOTON = 0 # (0, 2, 3, 4, 5, 6) + ON = 1 + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + return_all_margins=False, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = torch.zeros( + *x_ahead_margin.shape, len(SimpleLaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., SimpleLaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin = fill_margin( + margin, + noton_idx=SimpleLaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + if force_select: + # offset the margin to make sure that at least one is positive + margin_max = margin[..., SimpleLaneRelation.ON].max(dim=1)[0] + margin_offset = -margin_max.clip(max=0).detach() + 1e-6 + margin[..., SimpleLaneRelation.ON] = ( + margin[..., SimpleLaneRelation.ON] + margin_offset[:, None, :] + ) + + if force_unique: + second_largest_margin = margin[..., 1].topk(2, 1, sorted=True)[0][:, 1] + margin_offset = -second_largest_margin.clip(min=0).detach() + margin[..., 1] = margin[..., 1] + margin_offset[:, None, :] + margin[..., SimpleLaneRelation.NOTON] = -margin[..., SimpleLaneRelation.ON] + + flag = get_flag(margin) + return flag, margin + + +class LRLaneRelation(enum.IntEnum): + """Unitary lane mode that considers which lane the agent is on, which lane the ego is on the left of, and which lane the ego is on the right of""" + + NOTON = 0 # (0, 2, 3, 4) + ON = 1 + LEFTOF = 2 # (5) + RIGHTOF = 3 # (6) + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = torch.zeros( + *x_ahead_margin.shape, len(LRLaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., LRLaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LRLaneRelation.LEFTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_left_near, y_left_far, psi_margin], -1 + ).min(-1)[ + 0 + ] # further than left near, closer than left far + margin[..., LRLaneRelation.RIGHTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_right_near, y_right_far, psi_margin], + -1, + ).min(-1)[0] + + margin = fill_margin( + margin, + noton_idx=LRLaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + if force_select: + # offset the margin to make sure that at least one is positive + margin_max = margin[..., LRLaneRelation.ON].max(dim=1)[0] + margin_offset = -margin_max.clip(max=0).detach() + 1e-6 + margin[..., LRLaneRelation.ON] = ( + margin[..., LRLaneRelation.ON] + margin_offset[:, None, :] + ) + + if force_unique: + second_largest_margin = margin[..., LRLaneRelation.ON].topk( + 2, 1, sorted=True + )[0][:, 1] + margin_offset = -second_largest_margin.clip(min=0).detach() + margin[..., LRLaneRelation.ON] = ( + margin[..., LRLaneRelation.ON] + margin_offset[:, None, :] + ) + idx_excl_noton = torch.arange(margin.shape[-1]) + idx_excl_noton = idx_excl_noton[idx_excl_noton != LRLaneRelation.NOTON] + margin[..., LRLaneRelation.NOTON] = -margin[..., idx_excl_noton].max(-1)[0] + flag = get_flag(margin) + return flag, margin + + +class LaneRelation(enum.IntEnum): + """ + pairwise lane mode that gives each agent-lane pair a categorical token describing the relationship between the agent and the lane + """ + + NOTON = 0 + ON = 1 + AHEAD = 2 + BEHIND = 3 + MISALIGN = 4 + LEFTOF = 5 + RIGHTOF = 6 + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = torch.zeros( + *x_ahead_margin.shape, len(LaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., LaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.AHEAD] = torch.stack( + [-x_ahead_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.BEHIND] = torch.stack( + [-x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.MISALIGN] = torch.stack( + [y_left_near, y_right_near, -psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.LEFTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_left_near, y_left_far, psi_margin], -1 + ).min(-1)[ + 0 + ] # further than left near, closer than left far + margin[..., LaneRelation.RIGHTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_right_near, y_right_far, psi_margin], + -1, + ).min(-1)[0] + + margin = fill_margin( + margin, + noton_idx=LaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + flag = get_flag(margin) + return flag, margin + + +def get_l2a_geometry(agent_xysc, lane_xysc, t_range=None, const_override={}): + const = LaneModeConst(**const_override) + # agent_xysc:[B,T,4], lane_xysc:[B,M,L,4] + # lane_mask: [B,M], agent_mask: [B,T] + B, T = agent_xysc.shape[:2] + M, L = lane_xysc.shape[1:3] + + # idx1 = max(int(T*0.3),1) + # idx2 = min(T-idx1,T-1) + + dx = GeoUtils.batch_proj_xysc( + agent_xysc.repeat_interleave(M, 0).reshape(-1, 4), + lane_xysc.repeat_interleave(T, 1).reshape(-1, L, 4), + ).reshape( + B, M, T, L, -1 + ) # [B,M,T,L,xdim] + close_idx = ( + dx[..., 0].abs().argmin(-1) + ) # Take first element (x-pos), and find closest index (of L) within each lane segment + proj_pts = dx.gather(-2, close_idx[..., None, None].repeat(1, 1, 1, 1, 4)).squeeze( + -2 + ) # Get projection points using the closest point for each lane seg [B,M,T,4] + psi = ratan2(proj_pts[..., 2], proj_pts[..., 3]).detach() + y_dev = proj_pts[..., 1] + # Hausdorff-like distance + x_ahead_margin = ( + const.X_ahead_thresh + dx[..., 0].max(-1)[0] + ) / const.longitudinal_scale # We only have to check the minimal value + x_behind_margin = ( + -dx[..., 0].min(-1)[0] + const.X_rear_thresh + ) / const.longitudinal_scale + y_left_near = (const.Y_near_thresh + y_dev) / const.lateral_scale + y_left_far = (const.Y_far_thresh + y_dev) / const.lateral_scale + y_right_near = (const.Y_near_thresh - y_dev) / const.lateral_scale + y_right_far = (const.Y_far_thresh - y_dev) / const.lateral_scale + psi_margin = (const.psi_thresh - psi.abs()) / const.heading_scale + if t_range is not None: + t0, t1 = t_range + x_ahead_margin = x_ahead_margin[:, :, t0:t1].mean(dim=2) + x_behind_margin = x_behind_margin[:, :, t0:t1].mean(dim=2) + y_left_near = y_left_near[:, :, t0:t1].mean(dim=2) + y_left_far = y_left_far[:, :, t0:t1].mean(dim=2) + y_right_near = y_right_near[:, :, t0:t1].mean(dim=2) + y_right_far = y_right_far[:, :, t0:t1].mean(dim=2) + psi_margin = psi_margin[:, :, t0:t1].mean(dim=2) + return ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) + + +def get_ypsi_dev(agent_xysc, lane_xysc): + # agent_xysc:[B,T,4], lane_xysc:[B,M,L,4] + # lane_mask: [B,M], agent_mask: [B,T] + B, T = agent_xysc.shape[:2] + M, L = lane_xysc.shape[1:3] + + # idx1 = max(int(T*0.3),1) + # idx2 = min(T-idx1,T-1) + + dx = GeoUtils.batch_proj_xysc( + agent_xysc.repeat_interleave(M, 0).reshape(-1, 4), + lane_xysc.repeat_interleave(T, 1).reshape(-1, L, 4), + ).reshape( + B, M, T, L, -1 + ) # [B,M,T,L,xdim] + close_idx = ( + dx[..., 0].abs().argmin(-1) + ) # Take first element (x-pos), and find closest index (of L) within each lane segment + proj_pts = dx.gather(-2, close_idx[..., None, None].repeat(1, 1, 1, 1, 4)).squeeze( + -2 + ) # Get projection points using the closest point for each lane seg [B,M,T,4] + psi = ratan2(proj_pts[..., 2], proj_pts[..., 3]).detach() + y_dev = proj_pts[..., 1] + return y_dev, psi + + +def fill_margin(margin, noton_idx, agent_mask=None, lane_mask=None): + idx_excl_noton = torch.arange(margin.shape[-1]) + idx_excl_noton = idx_excl_noton[idx_excl_noton != noton_idx] + # put anything that does not belong to all the classes above to NOTON + margin[..., noton_idx] = -margin[..., idx_excl_noton].max(-1)[ + 0 + ] # Negation of the max of the rest + + # Put anything that is masked out to NOTON + margin[..., noton_idx] = margin[..., noton_idx].masked_fill( + torch.logical_not(agent_mask).unsqueeze(1), 10 + ) # agents we're not considering set to noton + margin[..., idx_excl_noton] = margin[..., idx_excl_noton].masked_fill( + torch.logical_not(agent_mask)[:, None, :, None], -10 + ) + margin[..., noton_idx] = margin[..., noton_idx].masked_fill( + torch.logical_not(lane_mask).unsqueeze(2), 10 + ) + margin[..., idx_excl_noton] = margin[..., idx_excl_noton].masked_fill( + torch.logical_not(lane_mask)[:, :, None, None], -10 + ) + return margin + + +def get_flag(margin): + flag = (margin >= 0).float() + # HACK: suppress multiple on flags + # if flag.shape[-1] > 2: + # flag[..., 2:] = flag[..., 2:].masked_fill( + # (flag[..., 1:2] > 0).repeat_interleave(flag.shape[-1] - 2, -1), 0 + # ) + # assert (flag.sum(-1) == 1).all() + return flag + + +def get_ref_traj(agent_xyh, lane_xyh, des_vel, dt, T): + """calculate reference trajectory along the lane center + + Args: + agent_xyvsc (np.ndarray): B,3 + lane_xysc (np.ndarray): B,L,3 + des_vel (np.ndarray): B + """ + B = agent_xyh.shape[0] + delta_x, _, _ = GeoUtils.batch_proj(agent_xyh, lane_xyh) + indices = np.abs(delta_x).argmin(-1) + xrefs = list() + for agent, lane, idx, delta_x_i, vel in zip( + agent_xyh, lane_xyh, indices, delta_x, des_vel + ): + idx = min(idx, lane.shape[0] - 2) + s = np.linalg.norm( + lane[idx + 1 :, :2] - lane[idx:-1, :2], axis=-1, keepdims=True + ).cumsum() + s = np.insert(s, 0, 0.0) - delta_x_i[idx] + f = spint.interp1d( + s, + lane[idx:], + axis=0, + assume_sorted=True, + bounds_error=False, + fill_value="extrapolate", + ) + xref = f(vel * np.arange(1, T + 1) * dt) + if np.isnan(xref).any(): + xref = np.zeros((T, 3)) + xrefs.append(xref) + return np.stack(xrefs, 0) + + +def get_closest_lane_pts(xyh, lane): + delta_x, _, _ = GeoUtils.batch_proj(xyh, lane.center.xyh) + if isinstance(xyh, np.ndarray): + return np.abs(delta_x).argmin() + elif isinstance(xyh, torch.Tensor): + return delta_x.abs().argmin() + + +def test_edge(): + import pickle + import torch + + with open("sf_test.pkl", "rb") as f: + data = pickle.load(f) + lane_xyh = data["lane_xyh"] + lane_feat = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + ego_xycs = data["agent_hist"][:, 0, :, [0, 1, 6, 7]] + + get_l2a_geometry(ego_xycs, lane_feat, [0, 4]) + + print("123") + + +if __name__ == "__main__": + test_edge() diff --git a/diffstack/utils/log_utils.py b/diffstack/utils/log_utils.py new file mode 100644 index 0000000..dc755ab --- /dev/null +++ b/diffstack/utils/log_utils.py @@ -0,0 +1,84 @@ +""" +This file contains utility classes and functions for logging to stdout, stderr, +and to tensorboard. +""" +import sys +import os +import time +import pathlib +import json +import wandb +from omegaconf import OmegaConf, DictConfig + + +def prepare_logging(cfg: DictConfig, use_logging=True, use_wandb=None, verbose=True, wandb_init_retries=10): + # Logging + log_writer = None + model_dir = None + if not use_logging: + return log_writer, model_dir + if use_wandb is None: + use_wandb = cfg.run.wandb_logging + + # Create the log and model directory if they're not present. + model_dir = os.path.join(cfg.run.log_dir, + cfg.run.log_tag + time.strftime('-%d_%b_%Y_%H_%M_%S', time.localtime())) + cfg.logdir = model_dir + + pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) + + # Save config to model directory + cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False) + with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json: + json.dump(cfg_dict, conf_json) + + # wandb.tensorboard.patch(root_logdir=model_dir, pytorch=True) + + # WandB init. Put it in a loop because it can fail on ngc. + if use_wandb: + log_writer = init_wandb(cfg, wandb_init_retries=wandb_init_retries) + + if verbose: + print (f"Log path: {model_dir}") + + return log_writer, model_dir + + +def init_wandb(cfg, wandb_init_retries=10): + cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False) + for _ in range(wandb_init_retries): + try: + log_writer = wandb.init( + project=cfg.run.wandb_project, + name=f"{cfg.run['log_tag']}", + config=cfg_dict, + mode="offline" if cfg.run["interactive"] else "online", # sync_tensorboard=True, + settings=wandb.Settings(start_method="thread"), + ) + except: + continue + break + else: + raise ValueError("Could not connect to wandb") + return log_writer + + +class PrintLogger(object): + """ + This class redirects print statements to both console and a file. + """ + def __init__(self, log_file): + self.terminal = sys.stdout + print('STDOUT will be forked to %s' % log_file) + self.log_file = open(log_file, "a") + + def write(self, message): + self.terminal.write(message) + self.log_file.write(message) + self.log_file.flush() + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass diff --git a/diffstack/utils/loss_utils.py b/diffstack/utils/loss_utils.py new file mode 100644 index 0000000..bf13f88 --- /dev/null +++ b/diffstack/utils/loss_utils.py @@ -0,0 +1,858 @@ +""" +This file contains a collection of useful loss functions for use with torch tensors. +Partially borrowed from https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/loss_utils.py +""" +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import ( + VEH_VEH_collision, + VEH_PED_collision, + PED_VEH_collision, + PED_PED_collision, +) +import torch.nn.functional as F + + +def cosine_loss(preds, labels): + """ + Cosine loss between two tensors. + Args: + preds (torch.Tensor): torch tensor + labels (torch.Tensor): torch tensor + Returns: + loss (torch.Tensor): cosine loss + """ + sim = torch.nn.CosineSimilarity(dim=len(preds.shape) - 1)(preds, labels) + return -torch.mean(sim - 1.0) + + +def KLD_0_1_loss(mu, logvar): + """ + KL divergence loss. Computes D_KL( N(mu, sigma) || N(0, 1) ). Note that + this function averages across the batch dimension, but sums across dimension. + Args: + mu (torch.Tensor): mean tensor of shape (B, D) + logvar (torch.Tensor): logvar tensor of shape (B, D) + Returns: + loss (torch.Tensor): KL divergence loss between the input gaussian distribution + and N(0, 1) + """ + return -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1).mean() + + +def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2): + """ + KL divergence loss between two Gaussian distributions. This function + computes the average loss across the batch. + Args: + mu_1 (torch.Tensor): first means tensor of shape (B, D) + logvar_1 (torch.Tensor): first logvars tensor of shape (B, D) + mu_2 (torch.Tensor): second means tensor of shape (B, D) + logvar_2 (torch.Tensor): second logvars tensor of shape (B, D) + Returns: + loss (torch.Tensor): KL divergence loss between the two gaussian distributions + """ + return ( + -0.5 + * ( + 1.0 + + logvar_1 + - logvar_2 + - ((mu_2 - mu_1).pow(2) / logvar_2.exp()) + - (logvar_1.exp() / logvar_2.exp()) + ) + .sum(dim=1) + .mean() + ) + + +def KLD_discrete(logp, logq): + """KL divergence loss between two discrete distributions. This function + computes the average loss across the batch. + + Args: + logp (torch.Tensor): log probability of first discrete distribution (B,D) + logq (torch.Tensor): log probability of second discrete distribution (B,D) + """ + return (torch.exp(logp) * (logp - logq)).sum(dim=1) + + +def KLD_discrete_with_zero(p, q, logmin=None, logmax=None): + flag = p != 0 + # breakpoint() + p = p.clip(min=1e-8) + q = q.clip(min=1e-8) + logp = (torch.log(p) * flag).nan_to_num(0) + logq = (torch.log(q) * flag).nan_to_num(0) + logp = logp.clip(min=logmin, max=logmax) + logq = logq.clip(min=logmin, max=logmax) + return (p * (logp - logq) * flag).sum(dim=1) + + +def log_normal(x, m, v, avails=None): + """ + Log probability of tensor x under diagonal multivariate normal with + mean m and variance v. The last dimension of the tensors is treated + as the dimension of the Gaussian distribution - all other dimensions + are treated as independent Gaussians. Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, ..., D) + m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D) + v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D) + avails (torch.Tensor): availability of x and m + Returns: + log_prob (torch.Tensor): log probabilities of shape (B, ...) + """ + if avails is None: + element_wise = -0.5 * (torch.log(v) + (x - m).pow(2) / v + np.log(2 * np.pi)) + else: + element_wise = -0.5 * ( + torch.log(v) + ((x - m) * avails).pow(2) / v + np.log(2 * np.pi) + ) + log_prob = element_wise.sum(-1) + return log_prob + + +def log_normal_mixture(x, m, v, w=None, log_w=None): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + w (torch.Tensor): weights tensor - if provided, should be + shape (B, M) or (1, M) + log_w (torch.Tensor): log-weights tensor - if provided, should be + shape (B, M) or (1, M) + Returns: + log_prob (torch.Tensor): log probabilities of shape (B,) + """ + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + log_prob = log_normal(x, m, v) + if w is not None or log_w is not None: + # this weights the log probabilities by the mixture weights so we have log(w_i * N(x | m_i, v_i)) + if w is not None: + assert log_w is None + log_w = torch.log(w) + log_prob += log_w + # then compute log sum_i exp [log(w_i * N(x | m_i, v_i))] + # (B, M) -> (B,) + log_prob = log_sum_exp(log_prob, dim=1) + else: + # (B, M) -> (B,) + log_prob = log_mean_exp(log_prob, dim=1) # mean accounts for uniform weights + return log_prob + + +def NLL_GMM_loss(x, m, v, pi, avails=None, detach=True, mode="sum"): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + logpi (torch.Tensor): log probability of the modes (B,M) + detach (bool): option whether to detach all modes but the best one + mode (string): mode of loss, sum or max + + Returns: + -log_prob (torch.Tensor): log probabilities of shape (B,) + """ + if v is None: + v = torch.ones_like(m) + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + if avails is not None: + avails = avails.unsqueeze(1) + log_prob = log_normal(x, m, v, avails=avails) + if mode == "sum": + if detach: + max_flag = log_prob == log_prob.max(dim=1, keepdim=True)[0] + nonmax_flag = torch.logical_not(max_flag) + log_prob_detach = log_prob.detach() + NLL_loss = (-pi * log_prob * max_flag).sum(1).mean() + ( + -pi * log_prob_detach * nonmax_flag + ).sum(1).mean() + else: + NLL_loss = (-pi * log_prob).sum(1).mean() + elif mode == "max": + max_flag = log_prob == log_prob.max(dim=1, keepdim=True)[0] + NLL_loss = (-pi * log_prob * max_flag).sum(1).mean() + return NLL_loss + + +def log_mean_exp(x, dim): + """ + Compute the log(mean(exp(x), dim)) in a numerically stable manner. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): a tensor + dim (int): dimension along which mean is computed + Returns: + y (torch.Tensor): log(mean(exp(x), dim)) + """ + return log_sum_exp(x, dim) - np.log(x.size(dim)) + + +def log_sum_exp(x, dim=0): + """ + Compute the log(sum(exp(x), dim)) in a numerically stable manner. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): a tensor + dim (int): dimension along which sum is computed + Returns: + y (torch.Tensor): log(sum(exp(x), dim)) + """ + + max_x = torch.max(x, dim)[0] + new_x = x - max_x.unsqueeze(dim).expand_as(x) + + return max_x + (new_x.exp().sum(dim)).log() + + +def project_values_onto_atoms(values, probabilities, atoms): + """ + Project the categorical distribution given by @probabilities on the + grid of values given by @values onto a grid of values given by @atoms. + This is useful when computing a bellman backup where the backed up + values from the original grid will not be in the original support, + requiring L2 projection. + Each value in @values has a corresponding probability in @probabilities - + this probability mass is shifted to the closest neighboring grid points in + @atoms in proportion. For example, if the value in question is 0.2, and the + neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to + atom 0 and 0.2 of the probability weight will go to 1. + Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42 + + Args: + values: value grid to project, of shape (batch_size, n_atoms) + probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms) + atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms) + Returns: + new probability vectors that correspond to the L2 projection of the categorical distribution + onto @atoms + """ + + # make sure @atoms is shape (n_atoms,) + if len(atoms.shape) > 1: + atoms = atoms.squeeze(0) + + # helper tensors from @atoms + vmin, vmax = atoms[0], atoms[1] + d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:] + d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1] + + # ensure that @values grid is within the support of @atoms + clipped_values = values.clamp(min=vmin, max=vmax)[ + :, None, : + ] # (batch_size, 1, n_atoms) + clipped_atoms = atoms[None, :, None] # (1, n_atoms, 1) + + # distance between atom values in support + d_pos = (d_pos - atoms)[ + None, :, None + ] # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1) + d_neg = (atoms - d_neg)[ + None, :, None + ] # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1) + + # distances between all pairs of grid values + deltas = clipped_values - clipped_atoms # (batch_size, n_atoms, n_atoms) + + # computes eqn (7) in distributional RL paper by doing the following - for each + # output atom in @atoms, consider values that are close enough, and weight their + # probability mass contribution by the normalized distance in [0, 1] given + # by (1. - (z_j - z_i) / (delta_z)). + d_sign = (deltas >= 0.0).float() + delta_hat = (d_sign * deltas / d_pos) - ((1.0 - d_sign) * deltas / d_neg) + delta_hat = (1.0 - delta_hat).clamp(min=0.0, max=1.0) + probabilities = probabilities[:, None, :] + return (delta_hat * probabilities).sum(dim=2) + + +def trajectory_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + assert availabilities.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = availabilities.unsqueeze(-1) * weights_scaling + loss = torch.mean(crit(predictions, targets) * target_weights) + return loss + + +def MultiModal_trajectory_loss( + predictions, + targets, + availabilities, + prob, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), + calc_goal_reach=False, + gamma=None, + detach_nonopt=True, +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, M, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (M), (A), T] + prob (torch.Tensor): [B, (A), M] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + gamma (float): risk level for CVAR + detach_nonopt (Bool): if detaching the loss for the non-optimal mode + + Returns: + loss (torch.Tensor) + """ + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + bs, M, Na = predictions.shape[:3] + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = availabilities.unsqueeze(-1) * weights_scaling + if availabilities.ndim < targets.ndim: + target_weights = target_weights.unsqueeze(1) + loss_v = ( + crit(predictions, targets.unsqueeze(1).repeat_interleave(M, 1)) * target_weights + ) + loss_agg_agent = loss_v.reshape(bs, M, Na, -1).sum(-1).transpose(1, 2) + loss_agg = loss_v.reshape(bs, M, -1).sum(-1) + + if gamma is not None: + if prob.ndim == 2: + p_cvar = list() + for i in range(bs): + p_cvar.append( + CVaR_weight(loss_agg[i], prob[i], gamma, sign=1, end_idx=None) + ) + p_cvar = torch.stack(p_cvar, 0) + elif prob.ndim == 3: + p_cvar = list() + for i in range(bs): + p_cvar_i = list() + for j in range(Na): + p_cvar_i.append( + CVaR_weight( + loss_agg_agent[i, j], + prob[i, j], + gamma, + sign=1, + end_idx=None, + ) + ) + p_cvar_i = torch.stack(p_cvar_i, 0) + + p_cvar.append(p_cvar_i) + p_cvar = torch.stack(p_cvar, 0) + prob = p_cvar + + if detach_nonopt: + if prob.ndim == 2: + loss_agg_detached = loss_agg.detach() + min_flag = loss_agg == loss_agg.min(dim=1, keepdim=True)[0] + nonmin_flag = torch.logical_not(min_flag) + loss = ( + (loss_agg * min_flag * prob).sum() + + (loss_agg_detached * nonmin_flag * prob).sum() + ) / (availabilities.sum() + 0.01) + elif prob.ndim == 3: + loss_agg_detached = loss_agg_agent.detach() + min_flag = loss_agg_agent == loss_agg_agent.min(dim=2, keepdim=True)[0] + nonmin_flag = torch.logical_not(min_flag) + loss = ( + (loss_agg_agent * min_flag * prob).sum() + + (loss_agg_detached * nonmin_flag * prob).sum() + ) / (availabilities.sum() + 0.01) + else: + if prob.ndim == 2: + loss = (loss_agg * prob).sum() / (availabilities.sum() + 0.01) + elif prob.ndim == 3: + loss = (loss_agg_agent * prob).sum() / (availabilities.sum() + 0.01) + if calc_goal_reach: + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + num_frames = availabilities.shape[-1] + goal_mask = TensorUtils.to_one_hot(last_inds, num_class=num_frames) + goal_mask = goal_mask.unsqueeze(1).unsqueeze(-1) + if detach_nonopt: + if prob.ndim == 2: + goal_loss = ( + ( + loss_v * (min_flag * prob)[..., None, None, None] * goal_mask + ).sum() + + ( + loss_v.detach() + * (nonmin_flag * prob)[..., None, None, None] + * goal_mask + ).sum() + ) / (goal_mask.sum() + 0.01) + elif prob.ndim == 3: + goal_mask = goal_mask.transpose(1, 2) + goal_loss = ( + ( + loss_v.transpose(1, 2) + * (min_flag * prob)[..., None, None] + * goal_mask + ).sum() + + ( + loss_v.transpose(1, 2).detach() + * (nonmin_flag * prob)[..., None, None, None] + * goal_mask + ).sum() + ) / (goal_mask.sum() + 0.01) + else: + if prob.ndim == 2: + goal_loss = (loss_v * prob[..., None, None, None] * goal_mask).sum() / ( + goal_mask.sum() + 0.01 + ) + elif prob.ndim == 3: + goal_mask = goal_mask.transpose(1, 2) + goal_loss = ( + loss_v.transpose(1, 2) * prob[..., None, None] * goal_mask + ).sum() / (goal_mask.sum() + 0.01) + return loss, goal_loss + else: + return loss, None + + +def goal_reaching_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Final step loss between gt and predicted trajectories (normally used in conjunction with a forward dynamics model) + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + # compute loss mask by finding the last available target + num_frames = availabilities.shape[-1] + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + goal_mask = TensorUtils.to_one_hot( + last_inds, num_class=num_frames + ) # [B, (A), T] with the last frame set to 1 + # filter out samples that do not have available frames + available_samples_mask = availabilities.sum(-1) > 0 # [B, (A)] + goal_mask = goal_mask * available_samples_mask.unsqueeze(-1).float() # [B, (A), T] + goal_loss = trajectory_loss( + predictions, + targets, + availabilities=goal_mask, + weights_scaling=weights_scaling, + crit=crit, + ) + return goal_loss + + +def lane_regulation_loss(lane_flag, agent_mask): + return (lane_flag.mean(-1) * agent_mask).sum() / agent_mask.sum() + + +def weighted_trajectory_loss( + predictions, + targets, + target_weights, + total_count, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + weights (torch.Tensor): [B, (A), T] + total_count (float) + weight_scaling (torch.Tensor): [D], Defaults to None. + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + assert target_weights.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = target_weights.unsqueeze(-1) * weights_scaling + loss = torch.sum(crit(predictions, targets) * target_weights) / total_count + return loss + + +def CVaR_weight(val, p, gamma, sign=-1, end_idx=None): + q = torch.clamp(p / gamma, max=1.0) + if end_idx is None: + end_idx = p.shape[0] + assert (p[end_idx:] == 0).all() + remain = 1.0 + if sign == 1: + idx = torch.argsort(val[0:end_idx]) + else: + idx = torch.argsort(val[0:end_idx], descending=True) + i = 0 + for i in range(end_idx): + if q[idx[i]] > remain: + q[idx[i]] = remain + remain = 0.0 + else: + remain -= q[idx[i]] + return q + + +def weighted_multimodal_trajectory_loss( + predictions, + targets, + target_weights, + probability, + total_count, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), + gamma=None, +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, M, A, T, D] + targets (torch.Tensor): target trajectory [B, A, T, D] + target_weights (torch.Tensor): [B, A, T] + probability (torch.Tensor): [B,M] + total_count (float) + weight_scaling (torch.Tensor): [D], Defaults to None. + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + bs, M, A, T = predictions.shape[:4] + assert target_weights.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = target_weights.unsqueeze(-1) * weights_scaling + err = ( + crit( + targets.unsqueeze(1).repeat(1, predictions.size(1), 1, 1, 1), + predictions, + ) + * target_weights.unsqueeze(1) + * probability[:, :, None, None, None] + ) + if gamma is None: + max_idx = torch.max(probability, dim=-1)[1] + max_mask = torch.zeros( + [*err.shape[:2], 1, 1, 1], dtype=torch.bool, device=err.device + ) + max_mask[torch.arange(0, err.size(0)), max_idx] = True + nonmax_mask = ~max_mask + loss = ( + torch.sum((err * max_mask)) + torch.sum((err * nonmax_mask).detach()) + ) / total_count + else: + loss_agg = err.reshape(bs, M, -1).sum(-1) + + p_cvar = list() + for i in range(bs): + p_cvar.append( + CVaR_weight(loss_agg[i], probability[i], gamma, sign=1, end_idx=None) + ) + p_cvar = torch.stack(p_cvar, 0) + loss = (loss_agg * p_cvar).sum() / total_count + return loss + + +def likelihood_loss(likelihood): + return 1.0 - torch.mean(likelihood) + + +def lane_regularization_loss(lane_flags, weights, total_count, probability=None): + """penalizing the vehicle for exiting drivable area + + Args: + lane_flags (torch.Tensor): 1 for in the lane, 0 for out of the lane, [B, (M), (A), T, 1] + weights (torch.Tensor): [B, (A), T] + total_count (float): + probability (torch.Tensor, optional): [B,M]. Defaults to None. + + Returns: + [type]: [description] + """ + if probability is None: + loss = torch.sum(weights.unsqueeze(-1) * (1.0 - lane_flags)) / total_count + else: + if lane_flags.ndim == 4: + probability = probability[:, :, None, None] + elif lane_flags.ndim == 5: + probability = probability[:, :, None, None, None] + loss = ( + torch.sum( + weights.unsqueeze(-1).unsqueeze(1) * (1.0 - lane_flags) * probability + ) + / total_count + ) + return loss + return loss + + +def goal_reaching_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Final step loss between gt and predicted trajectories (normally used in conjunction with a forward dynamics model) + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + # compute loss mask by finding the last available target + num_frames = availabilities.shape[-1] + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + goal_mask = TensorUtils.to_one_hot( + last_inds, num_class=num_frames + ) # [B, (A), T] with the last frame set to 1 + # filter out samples that do not have available frames + available_samples_mask = availabilities.sum(-1) > 0 # [B, (A)] + goal_mask = goal_mask * available_samples_mask.unsqueeze(-1).float() # [B, (A), T] + goal_loss = trajectory_loss( + predictions, + targets, + availabilities=goal_mask, + weights_scaling=weights_scaling, + crit=crit, + ) + return goal_loss + + +def collision_loss( + pred_edges: Dict[str, torch.Tensor], + weight=None, + col_funcs=None, + keepdim=False, + return_dis=False, +): + """ + Calculate collision loss among predicted edges along a batch of trajectories + Args: + pred_edges (dict): A dict that maps collision types to box locations + col_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if col_funcs is None: + col_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_loss = 0.0 + dis_by_type = dict() + for et, fun in col_funcs.items(): + if et not in pred_edges: + continue + edges = pred_edges[et] + if edges.shape[0] == 0: + continue + dis = fun(edges[..., 0:3], edges[..., 3:6], edges[..., 6:8], edges[..., 8:]) + dis_by_type[et] = dis + coll_loss_tensor = ( + torch.sigmoid(-dis * 4).nan_to_num(0.0).sum(dim=-1) + ) # smooth collision loss + if weight is not None: + if keepdim: + coll_loss += coll_loss_tensor * weight / (weight.sum() + 1e-5) + else: + coll_loss += torch.sum(coll_loss_tensor * weight) / ( + weight.sum() + 1e-5 + ) + else: + if keepdim: + coll_loss += coll_loss_tensor + else: + coll_loss += coll_loss_tensor.sum(-1).mean() + if return_dis: + return coll_loss, dis_by_type + else: + return coll_loss + + +def collision_loss_masked(edges, type_mask, weight=None, col_funcs=None, keepdim=False): + if col_funcs is None: + col_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_loss = 0.0 + for k, v in type_mask.items(): + if edges.shape[0] == 0: + continue + dis = col_funcs[k]( + edges[..., 0:3], + edges[..., 3:6], + edges[..., 6:8], + edges[..., 8:], + ).min(dim=-1)[0] + coll_loss_tensor = torch.sigmoid(-dis * 4) * v + if weight is not None: + if keepdim: + coll_loss += coll_loss_tensor * weight / (weight.sum() + 1e-5) + else: + coll_loss += torch.sum(coll_loss_tensor * weight) / ( + weight.sum() + 1e-5 + ) + else: + if keepdim: + coll_loss += coll_loss_tensor + else: + coll_loss += coll_loss_tensor.sum(-1).mean() + return coll_loss + + +def discriminator_loss(likelihood_pred, likelihood_GT): + label = torch.cat( + (torch.zeros_like(likelihood_pred), torch.ones_like(likelihood_GT)), 0 + ) + return F.binary_cross_entropy(torch.cat((likelihood_pred, likelihood_GT)), label) + + +def compute_pred_loss( + recon_loss_type, pred_batch, target_traj, avails, prob, weights_scaling=None +): + if "z" in pred_batch: + z1 = torch.argmax(pred_batch["z"], dim=-1) + else: + z1 = None + if recon_loss_type == "NLL": + assert "logvar" in pred_batch["x_recons"] + bs, M, T, D = pred_batch["trajectories"].shape + var = ( + torch.exp(pred_batch["x_recons"]["logvar"]) + + torch.ones_like(pred_batch["x_recons"]["logvar"]) * 1e-4 + ).reshape(bs, M, -1) + if z1 is not None: + var = torch.gather(var, 1, z1.unsqueeze(-1).repeat(1, 1, var.size(-1))) + avails = ( + avails.unsqueeze(-1).repeat(1, 1, target_traj.shape[-1]).reshape(bs, -1) + ) + pred_loss = NLL_GMM_loss( + x=target_traj.reshape(bs, -1), + m=pred_batch["trajectories"].reshape(bs, M, -1), + v=var, + pi=prob, + avails=avails, + ) + pred_loss = pred_loss.mean() + + elif recon_loss_type == "MSE": + pred_loss, _ = MultiModal_trajectory_loss( + predictions=pred_batch["trajectories"], + targets=target_traj, + availabilities=avails, + prob=prob, + weights_scaling=weights_scaling, + ) + else: + raise NotImplementedError("{} is not implemented".format(recon_loss_type)) + return pred_loss + + +def diversity_score( + pred: torch.tensor, + avail: torch.tensor, + mode: str = "mean", + max_clip=0.4, +) -> torch.tensor: + """ + Compute the distance among trajectory samples at the last timestep + Args: + pred (torch.tensor): array of shape B x M x A x T x 2 + avail (torch.tensor): array of availability B x + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). + Returns: + torch.tensor: average displacement error (ADE) of the batch, an array of float numbers + """ + # compute pairwise distances at the last time step + traj_norm = torch.linalg.norm(pred[..., -1, :] - pred[..., 0, :], dim=-1).clip( + min=0.1 + ) + traj_norm = traj_norm[:, None, :, :, None].detach() + avail = avail.any(-1)[:, None, None, :] + pred = pred[..., -1, :] + error = torch.linalg.norm( + (pred[:, None, :] - pred[:, :, None]) * avail.unsqueeze(-1) / traj_norm, dim=-1 + ) + error = (error.clip(max=max_clip) * avail).sum(-1) / avail.sum(-1).clip(min=1) + # [B, M, M] + error = error.reshape(error.shape[0], -1) # [B, M * M] + if mode == "max": + error = error.max(-1) + elif mode == "mean": + error = error.mean(-1) + else: + raise ValueError(f"mode: {mode} not valid") + return error.mean() + + +def loss_clip(loss, max_loss=2.0): + loss_offset = F.relu(loss - max_loss).detach() + return loss - loss_offset diff --git a/diffstack/utils/math_utils.py b/diffstack/utils/math_utils.py new file mode 100644 index 0000000..ecd7dff --- /dev/null +++ b/diffstack/utils/math_utils.py @@ -0,0 +1,59 @@ +import torch +import numpy as np + +def soft_min(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((y-x)/2) + exp2 = expfun((x-y)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_max(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((x-y)/2) + exp2 = expfun((y-x)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_sat(x,x_min=None,x_max=None,gamma=5): + if x_min is None and x_max is None: + return x + elif x_min is None and x_max is not None: + return soft_min(x,x_max,gamma) + elif x_min is not None and x_max is None: + return soft_max(x,x_min,gamma) + else: + if isinstance(x_min,torch.Tensor) or isinstance(x_min,np.ndarray): + assert (x_max>x_min).all() + else: + assert x_max>x_min + xc = x - (x_min+x_max)/2 + if isinstance(x,torch.Tensor): + return xc/(torch.pow(1+torch.pow(torch.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + elif isinstance(x,np.ndarray): + return xc/(np.power(1+np.power(np.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + else: + raise Exception("data type not supported") + + +def Gaussian_importance_sampling(mu0,var0,mu1,var1,num_samples=1): + """ perform importance sampling between two Gaussian distributions + + Args: + mu0 (torch.Tensor): [B,D]: mean of the target Gaussian distribution + var0 (torch.Tensor): [B,D]: variance of the target Gaussian distribution + mu1 (torch.Tensor): [B,D]: mean of the proposal Gaussian distribution + var1 (torch.Tensor): [B,D]: variance of the proposal Gaussian distribution + num_samples (int, optional): number of samples. Defaults to 1. + + Returns: + samples: [B,num_samples,D]: samples from the proposal Gaussian distribution + log_weights: [B,num_samples]: log weights of the samples + """ + samples = torch.randn([*mu1.shape[:-1],num_samples,mu1.shape[-1]],device=mu1.device)*torch.sqrt(var1).unsqueeze(-2)+mu1.unsqueeze(-2) + log_weights = -0.5*torch.log(2*np.pi*var0.unsqueeze(-2))-0.5*torch.pow(samples-mu0.unsqueeze(-2),2)/var0.unsqueeze(-2)+0.5*torch.log(2*np.pi*var1.unsqueeze(-2))+0.5*torch.pow(samples-mu1.unsqueeze(-2),2)/var1.unsqueeze(-2) + return samples,log_weights diff --git a/diffstack/utils/metrics.py b/diffstack/utils/metrics.py new file mode 100644 index 0000000..a58f810 --- /dev/null +++ b/diffstack/utils/metrics.py @@ -0,0 +1,848 @@ +""" +Adapted from https://github.com/lyft/l5kit/blob/master/l5kit/l5kit/evaluation/metrics.py +""" + + +from typing import Callable, Dict, List, Union +import torch +import numpy as np + +from diffstack.utils.geometry_utils import ( + VEH_VEH_collision, + VEH_PED_collision, + PED_VEH_collision, + PED_PED_collision, + get_box_world_coords, +) +from diffstack.utils.loss_utils import log_normal +from diffstack.modules.predictors.trajectron_utils.model.components.gmm2d import GMM2D + +from trajdata import SceneBatch + + +metric_signature = Callable[ + [np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray +] + + +def _assert_shapes( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> None: + """ + Check the shapes of args required by metrics + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timesteps + Returns: + """ + assert ( + len(pred.shape) == 4 + ), f"expected 3D (BxMxTxC) array for pred, got {pred.shape}" + batch_size, num_modes, future_len, num_coords = pred.shape + + assert ground_truth.shape == ( + batch_size, + future_len, + num_coords, + ), f"expected 2D (Batch x Time x Coords) array for gt, got {ground_truth.shape}" + assert confidences.shape == ( + batch_size, + num_modes, + ), f"expected 2D (Batch x Modes) array for confidences, got {confidences.shape}" + + assert np.allclose(np.sum(confidences, axis=1), 1), "confidences should sum to 1" + assert avails.shape == ( + batch_size, + future_len, + ), f"expected 1D (Time) array for avails, got {avails.shape}" + # assert all data are valid + assert np.isfinite(pred).all(), "invalid value found in pred" + assert np.isfinite(ground_truth).all(), "invalid value found in gt" + assert np.isfinite(confidences).all(), "invalid value found in confidences" + assert np.isfinite(avails).all(), "invalid value found in avails" + + +def batch_neg_multi_log_likelihood( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Compute a negative log-likelihood for the multi-modal scenario. + log-sum-exp trick is used here to avoid underflow and overflow, For more information about it see: + https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations + https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/ + https://leimao.github.io/blog/LogSumExp/ + For more details about used loss function and reformulation, please see + https://github.com/lyft/l5kit/blob/master/competition.md. + Args: + ground_truth (np.ndarray): array of shape (batchsize)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batchsize)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape ((batchsize)xmodes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batchsize)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: negative log-likelihood for this batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + + with np.errstate( + divide="ignore" + ): # when confidence is 0 log goes to -inf, but we're fine with it + error = np.log(confidences) - 0.5 * np.sum(error, axis=-1) # reduce timesteps + + # use max aggregator on modes for numerical stability + max_value = np.max( + error, axis=-1, keepdims=True + ) # error are negative at this point, so max() gives the minimum one + error = ( + -np.log(np.sum(np.exp(error - max_value), axis=-1)) - max_value + ) # reduce modes + return error + + +def batch_rmse( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the root mean squared error, computed using the stable nll + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: negative log-likelihood for this batch, an array of float numbers + """ + nll = batch_neg_multi_log_likelihood(ground_truth, pred, confidences, avails) + _, _, future_len, _ = pred.shape + + return np.sqrt(2 * nll / future_len) + + +def batch_prob_true_mode( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the probability of the true mode + Args: + ground_truth (np.ndarray): array of shape (timesteps)x(2D coords) + pred (np.ndarray): array of shape (modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: a (modes) numpy array + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + + with np.errstate( + divide="ignore" + ): # when confidence is 0 log goes to -inf, but we're fine with it + error = np.log(confidences) - 0.5 * np.sum(error, axis=-1) # reduce timesteps + + # use max aggregator on modes for numerical stability + max_value = np.max( + error, axis=-1, keepdims=True + ) # error are negative at this point, so max() gives the minimum one + + error = np.exp(error - max_value) / np.sum(np.exp(error - max_value)) + return error + + +def batch_time_displace( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the displacement at timesteps T + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: a (batch)x(timesteps) numpy array + """ + true_mode_error = batch_prob_true_mode(ground_truth, pred, confidences, avails) + true_mode_error = true_mode_error[:, :, np.newaxis] # add timesteps axis + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + return np.sum(true_mode_error * np.sqrt(error), axis=1) # reduce modes + + +def batch_average_displacement_error( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "mean", +) -> np.ndarray: + """ + Returns the average displacement error (ADE), which is the average displacement over all timesteps. + During calculation, confidences are ignored, and two modes are available: + - oracle: only consider the best hypothesis + - mean: average over all hypotheses + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + error = error ** 0.5 # calculate root of error (= L2 norm) + error = np.mean(error, axis=-1) # average over timesteps + if mode == "oracle": + error = np.min(error, axis=1) # use best hypothesis + elif mode == "mean": + error = np.mean(error*confidences, axis=1) # average over hypotheses + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_final_displacement_error( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "mean", +) -> np.ndarray: + """ + Returns the final displacement error (FDE), which is the displacement calculated at the last timestep. + During calculation, confidences are ignored, and two modes are available: + - oracle: only consider the best hypothesis + - mean: average over all hypotheses + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: final displacement error (FDE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + inds = np.arange(0, pred.shape[2]) + inds = (avails > 0) * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(axis=-1) + last_inds = np.tile(last_inds[:, np.newaxis, np.newaxis],(1,pred.shape[1],1)) + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + error = error ** 0.5 # calculate root of error (= L2 norm) + + # error = error[:, :, -1] # use last timestep + error = np.take_along_axis(error,last_inds,axis=2).squeeze(-1) + if mode == "oracle": + error = np.min(error, axis=-1) # use best hypothesis + elif mode == "mean": + error = np.mean(error, axis=-1) # average over hypotheses + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_average_diversity( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "max", +) -> np.ndarray: + """ + Compute the distance among trajectory samples averaged across time steps + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). + Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + # compute pairwise distances + error = np.linalg.norm( + pred[:, np.newaxis, :] - pred[:, :, np.newaxis], axis=-1 + ) # [B, M, M, T] + error = np.mean(error, axis=-1) # average over timesteps + error = error.reshape([error.shape[0], -1]) # [B, M * M] + if mode == "max": + error = np.max(error, axis=-1) + elif mode == "mean": + error = np.mean(error, axis=-1) + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_final_diversity( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "max", +) -> np.ndarray: + """ + Compute the distance among trajectory samples at the last timestep + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). + Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + # compute pairwise distances at the last time step + pred = pred[..., -1] + error = np.linalg.norm( + pred[:, np.newaxis, :] - pred[:, :, np.newaxis], axis=-1 + ) # [B, M, M] + error = error.reshape([error.shape[0], -1]) # [B, M * M] + if mode == "max": + error = np.max(error, axis=-1) + elif mode == "mean": + error = np.mean(error, axis=-1) + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def single_mode_metrics( + metrics_func, ground_truth: np.ndarray, pred: np.ndarray, avails: np.ndarray +): + """ + Run a metrics with single mode by inserting a mode dimension + + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(time)x(2D coords) + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): Optional, set to None when not applicable + calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: metrics values + """ + pred = pred[:, None] + conf = np.ones((pred.shape[0], 1)) + kwargs = dict(ground_truth=ground_truth, pred=pred, confidences=conf, avails=avails) + return metrics_func(**kwargs) + + +def batch_pairwise_collision_rate(agent_edges, collision_funcs=None): + """ + Count number of collisions among edge pairs in a batch + Args: + agent_edges (dict): A dict that maps collision types to box locations + collision_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if collision_funcs is None: + collision_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_rates = {} + for et, fun in collision_funcs.items(): + edges = agent_edges[et] + dis = fun( + edges[..., 0:3], + edges[..., 3:6], + edges[..., 6:8], + edges[..., 8:], + ) + dis = dis.min(-1)[0] # reduction over time + if isinstance(dis, np.ndarray): + coll_rates[et] = np.sum(dis <= 0) / float(dis.shape[0]) + else: + coll_rates[et] = torch.sum(dis <= 0) / float(dis.shape[0]) + return coll_rates + + +def batch_pairwise_collision_rate_masked(agent_edges, type_mask,collision_funcs=None): + """ + Count number of collisions among edge pairs in a batch + Args: + agent_edges (dict): A dict that maps collision types to box locations + collision_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if collision_funcs is None: + collision_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + coll_rates = {} + for et, fun in collision_funcs.items(): + if et in type_mask and type_mask[et].sum()>0: + dis = fun( + agent_edges[..., 0:3], + agent_edges[..., 3:6], + agent_edges[..., 6:8], + agent_edges[..., 8:], + ) + dis = dis.min(-1)[0] # reduction over time + if isinstance(dis, np.ndarray): + coll_rates[et] = np.sum((dis <= 0)*type_mask[et]) / type_mask[et].sum() + else: + coll_rates[et] = torch.sum((dis <= 0)*type_mask[et]) / type_mask[et].sum() + return coll_rates + + +def batch_detect_off_road(positions, drivable_region_map): + """ + Compute whether the given positions are out of drivable region + Args: + positions (torch.Tensor): a position (x, y) in rasterized frame [B, ..., 2] + drivable_region_map (torch.Tensor): binary drivable region maps [B, H, W] + + Returns: + off_road (torch.Tensor): whether each given position is off-road [B, ...] + """ + assert positions.shape[0] == drivable_region_map.shape[0] + assert drivable_region_map.ndim == 3 + b, h, w = drivable_region_map.shape + positions_flat = positions[..., 1].long() * w + positions[..., 0].long() + if positions_flat.ndim == 1: + positions_flat = positions_flat[:, None] + drivable = torch.gather( + drivable_region_map.flatten(start_dim=1), # [B, H * W] + dim=1, + index=positions_flat.long().flatten(start_dim=1), # [B, (all trailing dim flattened)] + ).reshape(*positions.shape[:-1]) + return 1 - drivable.float() + + +def batch_detect_off_road_boxes(positions, yaws, extents, drivable_region_map): + """ + Compute whether boxes specified by (@positions, @yaws, and @extents) are out of drivable region. + A box is considered off-road if at least one of its corners are out of drivable region + Args: + positions (torch.Tensor): agent centroid (x, y) in rasterized frame [B, ..., 2] + yaws (torch.Tensor): agent yaws in rasterized frame [B, ..., 1] + extents (torch.Tensor): agent extents [B, ..., 2] + drivable_region_map (torch.Tensor): binary drivable region maps [B, H, W] + + Returns: + box_off_road (torch.Tensor): whether each given box is off-road [B, ...] + """ + boxes = get_box_world_coords(positions, yaws, extents) # [B, ..., 4, 2] + off_road = batch_detect_off_road(boxes, drivable_region_map) # [B, ..., 4] + box_off_road = off_road.sum(dim=-1) > 0.5 + return box_off_road.float() + + +def GMM_loglikelihood(x, m, v, pi, avails=None, mode="mean"): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + logpi (torch.Tensor): log probability of the modes (B,M) + detach (bool): option whether to detach all modes but the best one + mode (string): mode of loss, sum or max + + Returns: + -log_prob (torch.Tensor): log probabilities of shape (B,) + """ + + if v is None: + v = torch.ones_like(m) + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + if avails is not None: + avails = avails.unsqueeze(1) + log_prob = log_normal(x, m, v, avails=avails) + if mode == "sum": + loglikelihood = (pi*log_prob).sum(1) + elif mode == "mean": + loglikelihood = (pi*log_prob).mean(1) + elif mode == "max": + loglikelihood = (pi*log_prob).max(1) + return loglikelihood + + +class DistanceBuffer(object): + """ class that stores the distance given x,y location + """ + def __init__(self): + self._buffer = dict() + + def __getitem__(self,key): + if key in self._buffer: + return self._buffer[key] + else: + return self.update(key) + + def update(self,key): + dis = np.linalg.norm(key) + self._buffer[key] = dis + self._buffer[-key] = dis + return dis + + +class RandomPerturbation(object): + """ + Add Gaussian noise to the trajectory + """ + def __init__(self, std: np.ndarray): + assert std.shape == (3,) and np.all(std >= 0) + self.std = std + + def perturb(self, obs): + """Given the observation object, add Gaussian noise to positions and yaws + + Args: + obs(Dict[torch.tensor]): observation dict + + Returns: + obs(Dict[torch.tensor]): perturbed observation + """ + obs = dict(obs) + target_traj = np.concatenate((obs["fut_pos"], obs["fut_yaw"]), axis=-1) + std = np.ones_like(target_traj) * self.std[None, :] + noise = np.random.normal(np.zeros_like(target_traj), std) + target_traj += noise + obs["fut_pos"] = target_traj[..., :2] + obs["fut_yaw"] = target_traj[..., :1] + return obs + + +class OrnsteinUhlenbeckPerturbation(object): + """ + Add Ornstein-Uhlenbeck noise to the trajectory + """ + def __init__(self,theta,sigma): + """ + Args: + theta (np.ndarray): converging factor of the OU process + sigma (np.ndarray): magnitude of the Gaussian noise added at each step + """ + assert theta.shape == (3,) and sigma.shape == (3,) + self.theta = theta + self.sigma = sigma + self.buffers = dict() + + def perturb(self,obs): + """Given the observation object, add Gaussian noise to positions and yaws + + Args: + obs(Dict[torch.tensor]): observation dict + + Returns: + obs(Dict[torch.tensor]): perturbed observation + """ + if isinstance(obs["fut_pos"],np.ndarray): + target_traj = np.concatenate((obs["fut_pos"], obs["fut_yaw"]), axis=-1) + bs = target_traj.shape[0] + T = target_traj.shape[-2] + if bs in self.buffers: + buffer = self.buffers[bs] + else: + buffer = [np.zeros([bs,3])] + self.buffers[bs]=buffer + while len(buffer) Dict[str, torch.Tensor]: + """ + Returns a dict of ade and fde metrics out of top-k most likely predictions. + Args: + trajs: tensor, (b, n, s, k, t, state) + log_probs: tensor, (b, n, {s,1}, k, {t,1}) + gt_trajs: tensor, (b, n, t, state) + k: int or list of ints for top k predictions + Returns: + dict with keys ['ade_top_k', 'fde_top_k'] and values of tensor, (b, n, state) + """ + assert trajs.ndim == 6 and log_probs.ndim == 5 and gt_trajs.ndim == 4 + sorted_idxs = torch.argsort(log_probs[..., [0], :], dim=-1, descending=True) + sorted_means = torch.gather(trajs, dim=-3, index=sorted_idxs.unsqueeze(-1).expand( + (-1, trajs.shape[1], -1, -1, trajs.shape[4], trajs.shape[5]))) + + errors = torch.linalg.norm(sorted_means - gt_trajs[:, :, None, None], dim=-1) # b, n, s, k, t + + if isinstance(k, int): + k = [k] + res = {} + for ki in k: + top_k_errors = errors[..., :ki, :] + + ade_k = torch.mean(top_k_errors, dim=-1) # b, n, s, k + min_ade_k = torch.min(ade_k, dim=-1).values # b, n, s + + fde_k = top_k_errors[..., -1] # b, n, s, k + min_fde_k = torch.min(fde_k, dim=-1).values # b, n, s + + res[f"ade_top_{ki}"] = min_ade_k + res[f"fde_top_{ki}"] = min_fde_k + + return res + + +## The next set of functions are replicates from the trajectron repo. + +def compute_ade_pt(predicted_trajs, gt_traj): + error = torch.linalg.norm(predicted_trajs - gt_traj, dim=-1) + ade = torch.mean(error, axis=-1) + return ade.flatten() + + +def compute_fde_pt(predicted_trajs, gt_traj): + final_error = torch.linalg.norm(predicted_trajs[..., -1, :] - gt_traj[..., -1, :], dim=-1) + return final_error.flatten() + + +def compute_nll_pt(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(gt_traj), min=-20.) + log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz.sum(dim=0), -log_p_y_xz_final.sum(dim=0) + + +def compute_min_afde_k_pt(predicted_dist, gt_traj, k): + means = predicted_dist.mus[..., :gt_traj.shape[-1]] + probs = predicted_dist.pis_cat_dist.probs + sorted_idxs = torch.argsort(probs[..., [0], :], dim=-1, descending=True) + sorted_means = torch.gather(means, dim=-2, index=sorted_idxs.unsqueeze(-1).expand((-1, -1, means.shape[2], -1, means.shape[4]))) + + errors = torch.linalg.norm(sorted_means - gt_traj.unsqueeze(-2), dim=-1) + top_k_errors = errors[..., :k] + + ade_k = torch.mean(top_k_errors, dim=-2) + min_ade_k = torch.min(ade_k, dim=-1).values + + fde_k = top_k_errors[..., -1, :] + min_fde_k = torch.min(fde_k, dim=-1).values + + return min_ade_k.flatten(), min_fde_k.flatten() + + +def compute_nll(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(torch.as_tensor(gt_traj)), min=-20.) + log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz[0].numpy(), -log_p_y_xz_final[0].numpy() + + +def compute_prediction_metrics(futures, + prediction_output_dict=None, + y_dists=None): + eval_ret = dict() + if prediction_output_dict is not None: + ade_errors = compute_ade_pt(prediction_output_dict, futures) + fde_errors = compute_fde_pt(prediction_output_dict, futures) + + eval_ret['ml_ade'] = ade_errors + eval_ret['ml_fde'] = fde_errors + + if y_dists is not None: + nll_means, nll_finals = compute_nll_pt(y_dists, futures) + min_ade_5, min_fde_5 = compute_min_afde_k_pt(y_dists, futures, k=5) + min_ade_10, min_fde_10 = compute_min_afde_k_pt(y_dists, futures, k=10) + + eval_ret['min_ade_5'] = min_ade_5 + eval_ret['min_fde_5'] = min_fde_5 + eval_ret['min_ade_10'] = min_ade_10 + eval_ret['min_fde_10'] = min_fde_10 + eval_ret['nll_mean'] = nll_means + eval_ret['nll_final'] = nll_finals + + return eval_ret + + +def scene_centric_prediction_metrics(scene_batch: SceneBatch, pred_outputs: Dict) -> Dict[str, torch.Tensor]: + + futures = scene_batch.agent_fut[:, :, :, :2].transpose(0, 1) # n, b, t, 2 + if "pred_dist_with_ego" not in pred_outputs or pred_outputs["pred_dist_with_ego"] is None: + return {} + pred_dist = pred_outputs["pred_dist_with_ego"] + + # Metrics for all agents + metrics_dict = compute_prediction_metrics(futures, prediction_output_dict=None, y_dists=pred_dist) + # Average over agents + metrics_dict = {k: v.mean(dim=0) for k, v in metrics_dict.items()} + + # Ego metrics + ego_pred_dist = GMM2D(pred_dist.log_pis[:1], pred_dist.mus[:1], pred_dist.log_sigmas[:1], pred_dist.corrs[:1]) + ego_metrics_dict = compute_prediction_metrics(futures[0], prediction_output_dict=None, y_dists=ego_pred_dist) + metrics_dict.update({"ego_"+k: v for k, v in ego_metrics_dict.items()}) + + # Move to cpu + metrics_dict = {k: v.detach().cpu() for k, v in metrics_dict.items()} + + return metrics_dict diff --git a/diffstack/utils/model_registrar.py b/diffstack/utils/model_registrar.py new file mode 100644 index 0000000..1e5855d --- /dev/null +++ b/diffstack/utils/model_registrar.py @@ -0,0 +1,88 @@ +import os +import torch +import torch.nn as nn +from pathlib import Path + + +def get_model_device(model): + return next(model.parameters()).device + + +class ModelRegistrar(nn.Module): + def __init__(self, model_dir, device): + super(ModelRegistrar, self).__init__() + self.model_dict = nn.ModuleDict() + self.model_dir = model_dir + self.device = device + + def forward(self): + raise NotImplementedError('Although ModelRegistrar is a nn.Module, it is only to store parameters.') + + def get_model(self, name, model_if_absent=None): + # 4 cases: name in self.model_dict and model_if_absent is None (OK) + # name in self.model_dict and model_if_absent is not None (OK) + # name not in self.model_dict and model_if_absent is not None (OK) + # name not in self.model_dict and model_if_absent is None (NOT OK) + + if name in self.model_dict: + return self.model_dict[name] + + elif model_if_absent is not None: + self.model_dict[name] = model_if_absent.to(self.device) + return self.model_dict[name] + + else: + raise ValueError(f'{name} was never initialized in this Registrar!') + + def get_name_match(self, name): + ret_model_list = nn.ModuleList() + for key in self.model_dict.keys(): + if name in key: + ret_model_list.append(self.model_dict[key]) + return ret_model_list + + def get_all_but_name_match(self, names): + if not isinstance(names, list) and not isinstance(names, tuple): + names = [names] + ret_model_list = nn.ModuleList() + for key in self.model_dict.keys(): + if all([name not in key for name in names]): + ret_model_list.append(self.model_dict[key]) + return ret_model_list + + def print_model_names(self): + print(self.model_dict.keys()) + + def save_models(self, curr_iter): + # Create the model directiory if it's not present. + save_path = os.path.join(self.model_dir, + 'model_registrar-%d.pt' % curr_iter) + + torch.save(self.model_dict, save_path) + + def load_models(self, iter_num): + save_path = os.path.join(self.model_dir, + 'model_registrar-%d.pt' % iter_num) + self.load_model_from_file(save_path) + + def load_model_from_file(self, file_path, except_contains=()): + file_path = str(Path(file_path).expanduser()) + print('\nLoading from ' + file_path) + + # Import error can happen here is trying to load checkpoint with old planner_cost object. + # To resolve it, one can remove the planner_cost from the checkpoint file using cleanup_checkpoint.py + # Alternatively, the old pred_metrics folder needs to be copied with environment/nuScenes_data/cost_functions.py + # Same can happend with `model` which reqires the original trajectron++ code to be in the path. + # sys.path.append('./trajectron/trajectron') + new_model_dict = torch.load(file_path, map_location=self.device) + + # Selectively update parameters + for k in new_model_dict: + if any([(substr in k) for substr in except_contains]): + print(f"Skipping {k}") + else: + self.model_dict[k] = new_model_dict[k] + + # self.model_dict = {k: v for k, v in self.model_dict.items() if substr not in k} + print('Loaded!') + print('') diff --git a/diffstack/utils/model_utils.py b/diffstack/utils/model_utils.py new file mode 100644 index 0000000..32f11e9 --- /dev/null +++ b/diffstack/utils/model_utils.py @@ -0,0 +1,772 @@ +import torch +import torch.nn.functional as F +import numpy as np +from diffstack.models.base_models import MLP +import torch.nn as nn +from diffstack.utils.geometry_utils import round_2pi, batch_proj_xysc, rel_xysc +import community as community_louvain +import networkx as nx +from scipy.sparse.csgraph import connected_components +from scipy.sparse import csr_matrix +import torch.distributions as td +import diffstack.utils.tensor_utils as TensorUtils +import math +from diffstack.models.TypeTransformer import CrossAttention + +class PED_PED_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(PED_PED_encode, self).__init__() + self.FC = MLP(10, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + deltax = x2[..., 0:2] - x1[..., 0:2] + input = torch.cat((deltax, x1[..., 2:4], x2[..., 2:4], size1, size2), dim=-1) + return self.FC(input) + + +class PED_VEH_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(PED_VEH_encode, self).__init__() + self.FC = MLP(10, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + deltax = x2[..., 0:2] - x1[..., 0:2] + veh_vel = torch.cat( + ( + torch.unsqueeze(x2[..., 2] * torch.cos(x2[..., 3]), dim=-1), + torch.unsqueeze(x2[..., 2] * torch.sin(x2[..., 3]), dim=-1), + ), + dim=-1, + ) + input = torch.cat((deltax, x1[..., 2:4], veh_vel, size1, size2), dim=-1) + return self.FC(input) + + +class VEH_PED_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(VEH_PED_encode, self).__init__() + self.FC = MLP(9, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + dx0 = x2[..., 0:2] - x1[..., 0:2] + theta = x1[..., 3] + dx = torch.cat( + ( + torch.unsqueeze( + dx0[..., 0] * torch.cos(theta) + torch.sin(theta) * dx0[..., 1], + dim=-1, + ), + torch.unsqueeze( + dx0[..., 1] * torch.cos(theta) - torch.sin(theta) * dx0[..., 0], + dim=-1, + ), + ), + dim=-1, + ) + dv = torch.cat( + ( + torch.unsqueeze( + x2[..., 2] * torch.cos(theta) + + torch.sin(theta) * x2[..., 3] + - x1[..., 2], + dim=-1, + ), + torch.unsqueeze( + x2[..., 3] * torch.cos(theta) - torch.sin(theta) * x2[..., 2], + dim=-1, + ), + ), + dim=-1, + ) + input = torch.cat( + (dx, torch.unsqueeze(x1[..., 2], dim=-1), dv, size1, size2), dim=-1 + ) + return self.FC(input) + + +class VEH_VEH_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(VEH_VEH_encode, self).__init__() + self.FC = MLP(11, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + dx0 = x2[..., 0:2] - x1[..., 0:2] + theta = x1[..., 3] + dx = torch.cat( + ( + torch.unsqueeze( + dx0[..., 0] * torch.cos(theta) + torch.sin(theta) * dx0[..., 1], + dim=-1, + ), + torch.unsqueeze( + dx0[..., 1] * torch.cos(theta) - torch.sin(theta) * dx0[..., 0], + dim=-1, + ), + ), + dim=-1, + ) + dtheta = x2[..., 3] - x1[..., 3] + dv = torch.cat( + ( + torch.unsqueeze(x2[..., 2] * torch.cos(dtheta) - x1[..., 2], dim=-1), + torch.unsqueeze(torch.sin(dtheta) * x2[..., 2], dim=-1), + ), + dim=-1, + ) + input = torch.cat( + ( + dx, + torch.unsqueeze(x1[..., 2], dim=-1), + dv, + torch.unsqueeze(torch.cos(dtheta), dim=-1), + torch.unsqueeze(torch.sin(dtheta), dim=-1), + size1, + size2, + ), + dim=-1, + ) + return self.FC(input) + + +def PED_rel_state(x, x0): + rel_x = torch.clone(x) + rel_x[..., 0:2] -= x0[..., 0:2] + return rel_x + + +def VEH_rel_state(x, x0): + rel_XY = x[..., 0:2] - x0[..., 0:2] + theta = x0[..., 3] + rel_x = torch.stack( + [ + rel_XY[..., 0] * torch.cos(theta) + rel_XY[..., 1] * torch.sin(theta), + rel_XY[..., 1] * torch.cos(theta) - rel_XY[..., 0] * torch.sin(theta), + x[..., 2], + x[..., 3] - x0[..., 3], + ], + dim=-1, + ) + rel_x[..., 3] = round_2pi(rel_x[..., 3]) + return rel_x + + +class PED_pre_encode(nn.Module): + def __init__(self, enc_dim, hidden_dim=[64], use_lane_info=False): + super(PED_pre_encode, self).__init__() + self.FC = MLP(4, enc_dim, hidden_dim) + + def forward(self, x): + return self.FC(x) + + +class VEH_pre_encode(nn.Module): + def __init__(self, enc_dim, hidden_dim=[64], use_lane_info=False): + super(VEH_pre_encode, self).__init__() + self.use_lane_info = use_lane_info + if use_lane_info: + self.FC = MLP(8, enc_dim, hidden_dim) + else: + self.FC = MLP(5, enc_dim, hidden_dim) + + def forward(self, x): + if self.use_lane_info: + input = torch.cat( + ( + x[..., 0:3], + torch.cos(x[..., 3:4]), + torch.sin(x[..., 3:4]), + x[..., 4:5], + torch.cos(x[..., 5:6]), + torch.sin(x[..., 5:6]), + ), + dim=-1, + ) + else: + input = torch.cat( + (x[..., 0:3], torch.cos(x[..., 3:]), torch.sin(x[..., 3:])), dim=-1 + ) + return self.FC(input) + +def break_graph(M, resol=1.0): + if isinstance(M, np.ndarray): + resol = resol * np.max(M) + G = nx.Graph() + for i in range(M.shape[0]): + G.add_node(i) + for i in range(M.shape[0]): + for j in range(i + 1, M.shape[0]): + if M[i, j] > 0: + G.add_edge(i, j, weight=M[i, j]) + partition = community_louvain.best_partition(G, resolution=resol) + elif isinstance(M, nx.classes.graph.Graph): + G = M + partition = community_louvain.best_partition(G, resolution=resol) + + while max(partition.values()) == 0 and resol >= 0.1: + resol = resol * 0.9 + partition = community_louvain.best_partition(G, resolution=resol) + return partition + + +def break_graph_recur(M, max_num): + n_components, labels = connected_components( + csgraph=csr_matrix(M), directed=False, return_labels=True + ) + idx = 0 + + while idx < n_components: + subset = np.where(labels == idx)[0] + if subset.shape[0] <= max_num: + idx += 1 + else: + partition = break_graph(M[np.ix_(subset, subset)]) + added_partition = 0 + for i in range(subset.shape[0]): + if partition[i] > 0: + labels[subset[i]] = n_components + partition[i] - 1 + added_partition = max(added_partition, partition[i]) + + n_components += added_partition + if added_partition == 0: + idx += 1 + + return n_components, labels + +def unpack_RNN_state(state_tuple): + # PyTorch returned LSTM states have 3 dims: + # (num_layers * num_directions, batch, hidden_size) + + state = torch.cat(state_tuple, dim=0).permute(1, 0, 2) + # Now state is (batch, 2 * num_layers * num_directions, hidden_size) + + state_size = state.size() + return torch.reshape(state, (-1, state_size[1] * state_size[2])) + +class Normal: + + def __init__(self, mu=None, logvar=None, params=None): + super().__init__() + if params is not None: + self.mu, self.logvar = torch.chunk(params, chunks=2, dim=-1) + else: + assert mu is not None + assert logvar is not None + self.mu = mu + self.logvar = logvar + self.sigma = torch.exp(0.5 * self.logvar) + + def rsample(self): + eps = torch.randn_like(self.sigma) + return self.mu + eps * self.sigma + + def sample(self): + return self.rsample() + + def kl(self, p=None): + """ compute KL(q||p) """ + if p is None: + kl = -0.5 * (1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + else: + term1 = (self.mu - p.mu) / (p.sigma + 1e-8) + term2 = self.sigma / (p.sigma + 1e-8) + kl = 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2) + return kl + + def mode(self): + return self.mu + def pseudo_sample(self,n_sample): + sigma_points = torch.stack((self.mu,self.mu-self.sigma,self.mu+self.sigma),1) + if n_sample<=1: + return sigma_points[:,:n_sample] + else: + remain_n = n_sample-3 + sigma_tiled = self.sigma.unsqueeze(1).repeat_interleave(remain_n,1) + mu_tiled = self.mu.unsqueeze(1).repeat_interleave(remain_n,1) + sample = torch.randn_like(sigma_tiled)*sigma_tiled+mu_tiled + return torch.cat([sigma_points,sample],1) + +class Categorical: + + def __init__(self, probs=None, logits=None, temp=0.01): + super().__init__() + self.logits = logits + self.temp = temp + if probs is not None: + self.probs = probs + else: + assert logits is not None + self.probs = torch.softmax(logits, dim=-1) + self.dist = td.OneHotCategorical(self.probs) + + def rsample(self,n_sample=1): + relatex_dist = td.RelaxedOneHotCategorical(self.temp, self.probs) + return relatex_dist.rsample((n_sample,)).transpose(0,-2) + + def sample(self): + return self.dist.sample() + + def pseudo_sample(self,n_sample): + D = self.probs.shape[-1] + idx = self.probs.argsort(-1,descending=True) + assert n_sample<=D + return TensorUtils.to_one_hot(idx[...,:n_sample],num_class=D) + + + + def kl(self, p=None): + """ compute KL(q||p) """ + if p is None: + p = Categorical(logits=torch.zeros_like(self.probs)) + kl = td.kl_divergence(self.dist, p.dist) + return kl + + def mode(self): + argmax = self.probs.argmax(dim=-1) + one_hot = torch.zeros_like(self.probs) + one_hot.scatter_(1, argmax.unsqueeze(1), 1) + return one_hot + + + +def initialize_weights(modules): + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + if m.bias is not None: nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: nn.init.constant_(m.bias, 0) + +class AFMLP(nn.Module): + + def __init__(self, input_dim, hidden_dims=(128, 128), activation='tanh'): + super().__init__() + if activation == 'tanh': + self.activation = torch.tanh + elif activation == 'relu': + self.activation = torch.relu + elif activation == 'sigmoid': + self.activation = torch.sigmoid + + self.out_dim = hidden_dims[-1] + self.affine_layers = nn.ModuleList() + last_dim = input_dim + for nh in hidden_dims: + self.affine_layers.append(nn.Linear(last_dim, nh)) + last_dim = nh + + initialize_weights(self.affine_layers.modules()) + + def forward(self, x): + for affine in self.affine_layers: + x = self.activation(affine(x)) + return x + +def rotation_2d_torch(x, theta, origin=None): + if origin is None: + origin = torch.zeros(2).to(x.device).to(x.dtype) + norm_x = x - origin + norm_rot_x = torch.zeros_like(x) + norm_rot_x[..., 0] = norm_x[..., 0] * torch.cos(theta) - norm_x[..., 1] * torch.sin(theta) + norm_rot_x[..., 1] = norm_x[..., 0] * torch.sin(theta) + norm_x[..., 1] * torch.cos(theta) + rot_x = norm_rot_x + origin + return rot_x, norm_rot_x + + +class ExpParamAnnealer(nn.Module): + + def __init__(self, start, finish, rate, cur_epoch=0): + super().__init__() + self.register_buffer('start', torch.tensor(start)) + self.register_buffer('finish', torch.tensor(finish)) + self.register_buffer('rate', torch.tensor(rate)) + self.register_buffer('cur_epoch', torch.tensor(cur_epoch)) + + def step(self): + self.cur_epoch += 1 + + def set_epoch(self, epoch): + self.cur_epoch.fill_(epoch) + + def val(self): + return self.finish - (self.finish - self.start) * (self.rate ** self.cur_epoch) + +class IntegerParamAnnealer(nn.Module): + + def __init__(self, start, finish, length, cur_epoch=0): + super().__init__() + self.register_buffer('start', torch.tensor(start)) + self.register_buffer('finish', torch.tensor(finish)) + self.register_buffer('length', torch.tensor(length)) + self.register_buffer('cur_epoch', torch.tensor(cur_epoch)) + + def step(self): + self.cur_epoch += 1 + + def set_epoch(self, epoch): + self.cur_epoch.fill_(epoch) + + def val(self): + return self.finish if self.cur_epoch>=self.length else self.start+int((self.finish-self.start)*self.cur_epoch/self.length) + + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(8, channels) + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def Gaussian_RBF_conv(sigma,radius,device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + X = torch.range(-radius,radius) + Y = torch.range(-radius,radius) + XY = torch.meshgrid(X,Y) + dis_sq = XY[0]**2+XY[1]**2 + RBF = torch.exp(-dis_sq/(2*sigma**2)).to(device) + RBF = torch.nn.Parameter(RBF[None,None]/RBF.sum(),requires_grad=False) + net = nn.Conv2d(1,1,2*radius+1,bias=False,padding=radius,device=device) + net.weight = RBF + return net + +def agent2agent_edge(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # assumming x1 and x2 are state of the form [x,y,v,s,c,w,l,type_one_hot] + N1 = x1.shape[1] + N2 = x2.shape[1] + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + x1 = x1.unsqueeze(2).repeat_interleave(N2,2) + x2 = x2.unsqueeze(1).repeat_interleave(N1,1) + x1_xysc = x1[...,[0,1,3,4]] + x2_xysc = x2[...,[0,1,3,4]] + dx_xysc = rel_xysc(x1_xysc,x2_xysc)*(x1_flag.unsqueeze(2)*x2_flag.unsqueeze(1)).unsqueeze(-1) + dx_xy = (dx_xysc[...,:2]/scale).clip(min=clip[0],max=clip[1]) + + dx_xysc = torch.cat([dx_xy,dx_xysc[...,2:]],-1) + + v1 = x1[...,2:3] + v2 = x2[...,2:3] + edge = torch.cat([dx_xysc,v2*dx_xysc[...,2:3],v2*dx_xysc[...,3:4],x2[...,5:]],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def agent2lane_edge_proj(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,N,d], x2:[B,M,L*4] + x2 = x2.reshape(*x2.shape[:-1],-1,4) + N = x1.shape[1] + M = x2.shape[1] + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + dx = batch_proj_xysc(x1[:,:,None,[0,1,3,4]].repeat_interleave(M,2),x2[:,None].repeat_interleave(N,1)) + dx_xy = (dx[...,:2]/scale).clip(min=clip[0],max=clip[1]) + dx = torch.cat([dx_xy,dx[...,2:]],-1) + dx = (dx * x1_flag[:,:,None,None,None]*x2_flag[:,None,:,:,None]) + + min_idx = dx[...,0].abs().argmin(3) + min_pts = torch.gather(dx,3,min_idx[...,None,None].repeat(1,1,1,1,4)).squeeze(-2) + edge = torch.cat([min_pts,dx[:,:,:,0],dx[:,:,:,-1]],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def agent2lane_edge_per_pts(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,N,4], x2:[B,M,4] + B,N = x1.shape[:2] + M = x2.shape[1] + rel_coord = rel_xysc(x1.repeat_interleave(M,1),x2.repeat_interleave(N,0).reshape(B,N*M,4)).reshape(B,N,M,4) + + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + rel_xy = (rel_coord[...,:2]/scale).clip(min=clip[0],max=clip[1]) + rel_coord = torch.cat([rel_xy,rel_coord[...,2:]],-1) + rel_coord = rel_coord * x1_flag[:,:,None,None]*x2_flag[:,None,:,None] + + return rel_coord + +def lane2lane_edge(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,M,L*4], x2:[B,M,L*4] + x1 = x1.reshape(*x1.shape[:-1],-1,4).unsqueeze(2).repeat_interleave(x2.shape[1],2) + x2 = x2.reshape(*x2.shape[:-1],-1,4).unsqueeze(1).repeat_interleave(x1.shape[1],1) + + x1s = x1[...,0,:] + x1e = x1[...,-1,:] + x2s = x2[...,0,:] + x2e = x2[...,-1,:] + dx1 = rel_xysc(x1s,x2s) + dx2 = rel_xysc(x1s,x2e) + dx3 = rel_xysc(x1e,x2s) + dx4 = rel_xysc(x1e,x2e) + dx1 = torch.cat([(dx1[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx1[...,2:]],-1) + dx2 = torch.cat([(dx2[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx2[...,2:]],-1) + dx3 = torch.cat([(dx3[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx3[...,2:]],-1) + dx4 = torch.cat([(dx4[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx4[...,2:]],-1) + edge = torch.cat([dx1,dx2,dx3,dx4],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def edge_as_aux1(x1,x2): + # when the edge encoding is passed in via x1 + # x1: [B,N1,N2*d], x2:[B,N2,_] + B,N1 = x1.shape[:2] + N2 = x2.shape[1] + return x1.reshape(B,N1,N2,-1) + +def edge_as_aux2(x1,x2): + # when the edge encoding is passed in via x1 + # x1: [B,N1,_], x2:[B,N2,N1*d] + B,N1 = x1.shape[:2] + N2 = x2.shape[1] + return x2.reshape(B,N2,N1,-1).transpose(1,2) + +class Agent_emb(nn.Module): + def __init__(self, raw_dim,n_embd): + super(Agent_emb, self).__init__() + self.raw_dim = raw_dim + self.n_embd = n_embd + self.FC = nn.Linear(raw_dim, n_embd) + + def forward(self, input): + if input.size(-1) 0: + col_loss += torch.sum(torch.sigmoid(-dis*4) * scale*type_mask[et][:,None,:,None], dim=[2,3])/T + + return col_loss + + +def get_drivable_area_loss( + ego_trajectories, raster_from_agent, dis_map, ego_extents +): + """Cost for road departure.""" + with torch.no_grad(): + + lane_flags = rasterized_ROI_align( + dis_map, + ego_trajectories[..., :2], + ego_trajectories[..., 2:], + raster_from_agent, + torch.ones(*ego_trajectories.shape[:3] + ).to(ego_trajectories.device), + ego_extents.unsqueeze(1).repeat(1, ego_trajectories.shape[1], 1), + 1, + ).squeeze(-1) + return lane_flags.max(dim=-1)[0] + +def get_lane_loss_simple(ego_trajectories, raster_from_agent, dis_map): + h,w = dis_map.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points(ego_trajectories[...,:2],raster_from_agent) + raster_xy[...,0] = raster_xy[...,0].clip(0,w-1e-5) + raster_xy[...,1] = raster_xy[...,1].clip(0,h-1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = (raster_xy[...,1]*w+raster_xy[...,0]) + raster_xy_flat = raster_xy_flat.flatten() + lane_loss = (dis_map.flatten()[raster_xy_flat]).reshape(*raster_xy.shape[:2]) + return lane_loss.max(dim=-1)[0] + +def get_terminal_likelihood_reward( + ego_trajectories, raster_from_agent, log_likelihood +): + """Cost for road departure.""" + + log_likelihood = (log_likelihood-log_likelihood.mean())/log_likelihood.std() + h,w = log_likelihood.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points(ego_trajectories[...,-1,:2],raster_from_agent) + raster_xy[...,0] = raster_xy[...,0].clip(0,w-1e-5) + raster_xy[...,1] = raster_xy[...,1].clip(0,h-1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = (raster_xy[...,1]*w+raster_xy[...,0]) + + ll_reward = log_likelihood.flatten()[raster_xy_flat] + return ll_reward + +def get_progress_reward(ego_trajectories,d_sat = 10): + dis = torch.linalg.norm(ego_trajectories[...,-1,:2]-ego_trajectories[...,0,:2],dim=-1) + return 2/np.pi*torch.atan(dis/d_sat) + + +def get_total_distance(ego_trajectories): + """Reward that incentivizes progress.""" + # Assume format [..., T, 3] + assert ego_trajectories.shape[-1] == 3 + diff = ego_trajectories[..., 1:, :] - ego_trajectories[..., :-1, :] + dist = torch.norm(diff[..., :2], dim=-1) + total_dist = torch.sum(dist, dim=-1) + return total_dist + + +def ego_sample_planning( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + raster_from_agent, + dis_map, + weights, + log_likelihood=None, + col_funcs=None, +): + """A basic cost function for prediction-and-planning""" + col_loss = get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + col_funcs, + ) + lane_loss = get_drivable_area_loss( + ego_trajectories, raster_from_agent, dis_map, ego_extents + ) + progress = get_total_distance(ego_trajectories) + + log_likelihood = 0 if log_likelihood is None else log_likelihood + if log_likelihood.ndim==3: + log_likelihood = get_terminal_likelihood_reward(ego_trajectories, raster_from_agent, log_likelihood) + + total_score = ( + + weights["likelihood_weight"] * log_likelihood + + weights["progress_weight"] * progress + - weights["collision_weight"] * col_loss + - weights["lane_weight"] * lane_loss + ) + + + return torch.argmax(total_score, dim=1) + + +class TreeMotionPolicy(object): + """ A trajectory tree policy as the result of contingency planning + + """ + def __init__(self,stage,num_frames_per_stage,ego_root,scenario_root,cost_to_go,leaf_idx,curr_node): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_root = ego_root + self.scenario_root = scenario_root + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node= curr_node + + def identify_branch(self,ego_node,scene_traj): + + assert scene_traj.shape[-2]=scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + ego_leaf_index = self.leaf_idx[ego_node] + while remain_traj.shape[1]>0: + seg_length = min(remain_traj.shape[-2],self.num_frames_per_stage) + dis = [torch.linalg.norm(child.traj[ego_leaf_index,:,:seg_length,:2]-remain_traj[:,:seg_length,:2],dim=-1).sum().item()\ + for child in curr_scenario_node.children] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[...,seg_length:,:] + remain_num_frames = curr_scenario_node.traj.shape[-2]-seg_length + if curr_scenario_node.stage>=self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self,scene_traj,horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0]-1-T + assert remain_num_frames>-self.num_frames_per_stage + if remain_num_frames<=0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node,scene_traj) + Q = [self.cost_to_go[(child,curr_scenario_node)] for child in self.curr_node.children] + idx = torch.argmin(torch.tensor(Q)).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames+=self.curr_node.traj.shape[0] + traj = self.curr_node.traj[-remain_num_frames:,TRAJ_INDEX] + if not self.curr_node.isleaf(): + traj = torch.cat((traj,self.curr_node.children[0].traj[:,TRAJ_INDEX]),-2) + if traj.shape[0]>=horizon: + return traj[:horizon] + else: + traj_patched = torch.cat((traj,traj[-1].tile(horizon-traj.shape[0],1))) + return traj_patched + +class VectorizedTreeMotionPolicy(TreeMotionPolicy): + """ A vectorized trajectory tree policy as the result of contingency planning + + """ + def __init__(self,stage,num_frames_per_stage,ego_tree,children_indices,scenario_tree,cost_to_go,leaf_idx,curr_node): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_tree = ego_tree + self.ego_root = ego_tree[0][0] + self.children_indices = children_indices + self.scenario_tree = scenario_tree + self.scenario_root = scenario_tree[0][0] + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node= curr_node + + def identify_branch(self,ego_node,scene_traj): + + assert scene_traj.shape[-2]=scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + + stage = ego_node.depth + ego_stage_index = self.ego_tree[stage].index(ego_node) + ego_leaf_index = self.leaf_idx[stage][ego_stage_index].item() + while remain_traj.shape[1]>0: + seg_length = min(remain_traj.shape[-2],self.num_frames_per_stage) + dis = [torch.linalg.norm(child.traj[ego_leaf_index,:,:seg_length,:2]-remain_traj[:,:seg_length,:2],dim=-1).sum().item()\ + for child in curr_scenario_node.children] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[...,seg_length:,:] + remain_num_frames = curr_scenario_node.traj.shape[-2]-seg_length + if curr_scenario_node.stage>=self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self,scene_traj,horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0]-1-T + assert remain_num_frames>-self.num_frames_per_stage + if remain_num_frames<=0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node,scene_traj) + assert curr_scenario_node.depth==self.curr_node.depth + stage = self.curr_node.depth + scene_node_idx = self.scenario_tree[stage].index(curr_scenario_node) + curr_node_idx = self.ego_tree[stage].index(self.curr_node) + Q = self.cost_to_go[stage][scene_node_idx,self.children_indices[stage][curr_node_idx]] + idx = torch.argmin(Q).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames+=self.curr_node.traj.shape[0] + + traj = self.curr_node.traj[-remain_num_frames:,TRAJ_INDEX] + if not self.curr_node.isleaf(): + traj = torch.cat((traj,self.curr_node.children[0].traj[:,TRAJ_INDEX]),-2) + if traj.shape[0]>=horizon: + return traj[:horizon] + else: + traj_patched = torch.cat((traj,traj[-1].tile(horizon-traj.shape[0],1))) + return traj_patched + + +def tiled_to_tree(total_traj,prob,num_stage,num_frames_per_stage,M): + """Turning a trajectory tree in tiled form to a tree data structure + + Args: + total_traj (torch.tensor or np.ndarray): tiled trajectory tree + prob (torch.tensor or np.ndarray): probability of the modes + num_stage (int): number of layers of the tree + num_frames_per_stage (int): number of time frames per layer + M (int): branching factor + + Returns: + nodes (dict[int:List(AgentTrajTree)]): all branches of the trajectory tree nodes indexed by layer + """ + + # total_traj = TensorUtils.reshape_dimensions_single(total_traj,2,3,[M]*num_stage) + x0 = AgentTrajTree(None, None, 0) + nodes = defaultdict(lambda:list()) + nodes[0].append(x0) + for t in range(num_stage): + interval = M**(num_stage-t-1) + tiled_traj = total_traj[...,::interval,:,t*num_frames_per_stage:(t+1)*num_frames_per_stage,:] + for i in range(M**(t+1)): + parent_idx = int(i/M) + p = prob[:,i*interval:(i+1)*interval].sum(-1) + node = AgentTrajTree(tiled_traj[:,i], nodes[t][parent_idx], t + 1, prob=p) + nodes[t+1].append(node) + return nodes + + +def contingency_planning(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. + + Returns: + TreeMotionPolicy: optimal motion policy + """ + + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + device = agent_traj.device + leaf_idx = defaultdict(lambda:list()) + for stage in range(num_stage,-1,-1): + for node in ego_tree[stage]: + if node.isleaf(): + leaf_idx[node]=[ego_tree[stage].index(node)] + else: + leaf_idx[node] = [] + for child in node.children: + leaf_idx[node] = leaf_idx[node]+leaf_idx[child] + + + V = dict() + L = dict() + Q = dict() + scenario_tree = tiled_to_tree(agent_traj,mode_prob,num_stage,num_frames_per_stage,M) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + for stage in range(num_stage,0,-1): + if stage==0: + total_loss = torch.zeros([1,1],device=device) + else: + ego_nodes = ego_tree[stage] + indices = [leaf_idx[node][0] for node in ego_nodes] + ego_traj = [node.traj[:,TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj,0) + agent_nodes = scenario_tree[stage] + agent_traj = [node.traj[indices] for node in agent_nodes] + agent_traj = torch.stack(agent_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes),1,1,1) + col_loss = get_collision_loss(ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes),1), + agent_extents.tile(len(agent_nodes),1,1), + agent_types.tile(len(agent_nodes),1), + col_funcs, + ) + + + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + + total_loss = weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[1],device=device).unsqueeze(0)*pert_std + if log_likelihood is not None and stage==num_stage: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + + for i in range(len(ego_nodes)): + for j in range(len(agent_nodes)): + L[(ego_nodes[i],agent_nodes[j])] = total_loss[j,i] + if stage==num_stage: + V[(ego_nodes[i],agent_nodes[j])] = float(total_loss[j,i]) + else: + children_cost_to_go = [Q[(child,agent_nodes[j])] for child in ego_nodes[i].children] + V[(ego_nodes[i],agent_nodes[j])] = float(total_loss[j,i])+min(children_cost_to_go) + + if stage>0: + for agent_node in scenario_tree[stage-1]: + cost_i = [] + prob_i = [] + for child in agent_node.children: + cost_i.append(V[ego_nodes[i],child]) + prob_i.append(child.prob[leaf_idx[ego_nodes[i]]].sum()) + cost_i = torch.tensor(cost_i,device=device) + prob_i = torch.stack(prob_i) + Q[(ego_nodes[i],agent_node)] = float((cost_i*prob_i).sum()/prob_i.sum()) + Q_root = [Q[(child,scenario_root)] for child in ego_root.children] + idx = torch.argmin(torch.tensor(Q_root)).item() + optimal_node = ego_root.children[idx] + motion_policy = TreeMotionPolicy(num_stage, + num_frames_per_stage, + ego_root, + scenario_root, + Q, + leaf_idx, + optimal_node) + motion_policy.get_plan(None,num_stage*num_frames_per_stage) + return motion_policy + +def contingency_planning_parallel(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. + + Returns: + TreeMotionPolicy: optimal motion policy + """ + device=agent_traj.device + children_indices = TrajTree.get_children_index_torch(ego_tree) + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + + leaf_idx = {num_stage:torch.arange(len(ego_tree[num_stage]),device=device)} + stage_prob = {num_stage:mode_prob.T} + for stage in range(num_stage-1,-1,-1): + leaf_idx[stage] = leaf_idx[stage+1][children_indices[stage][:,0]] + prob_next = stage_prob[stage+1] + stage_prob[stage] = prob_next.reshape(-1,M,prob_next.shape[-1])[:,:,children_indices[stage][:,0]].sum(1) + + V = dict() + L = dict() + Q = dict() + + + scenario_tree = tiled_to_tree(agent_traj,mode_prob,num_stage,num_frames_per_stage,M) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + + + + for stage in range(num_stage,-1,-1): + if stage==0: + total_loss = torch.zeros([1,1],device=device) + else: + #calculate stage cost + ego_nodes = ego_tree[stage] + ego_traj = [node.traj[:,TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj,0) + agent_nodes = scenario_tree[stage] + + agent_traj = [node.traj[leaf_idx[stage]] for node in agent_nodes] + + agent_traj = torch.stack(agent_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes),1,1,1) + col_loss = get_collision_loss(ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes),1), + agent_extents.tile(len(agent_nodes),1,1), + agent_types.tile(len(agent_nodes),1), + col_funcs, + ) + + + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + + total_loss = weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[1],device=device).unsqueeze(0)*pert_std + if log_likelihood is not None and stage==num_stage: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + L[stage] = total_loss + if stage==num_stage: + V[stage] = total_loss + else: + children_idx = children_indices[stage] + # add the last Q value as inf since empty children index are padded with -1 + Q_prime = torch.cat((Q[stage],torch.full([Q[stage].shape[0],1],np.inf,device=device)),1) + Q_by_node = Q_prime[:,children_idx] + V[stage] = total_loss+Q_by_node.min(dim=-1)[0] + + if stage>0: + children_V = V[stage] + children_V = children_V.reshape(-1,M,children_V.shape[-1]) + prob = stage_prob[stage] + prob = prob.reshape(-1,M,prob.shape[-1]) + prob_normalized = prob/prob.sum(dim=1,keepdim=True) + Q[stage-1] = (children_V*prob_normalized).sum(dim=1) + + idx = Q[0].argmin().item() + + motion_policy = VectorizedTreeMotionPolicy(num_stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + Q, + leaf_idx, + ego_root.children[idx]) + return motion_policy + +def one_shot_planning(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None, + strategy="all"): + + """Alternative of contingency planning, try to avoid all predicted trajectories + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. + + Returns: + TreeMotionPolicy: optimal motion policy + """ + assert strategy=="all" or strategy=="maximum" + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + + ego_traj = [node.total_traj[1:,TRAJ_INDEX] for node in ego_tree[num_stage]] + ego_traj = torch.stack(ego_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(1).repeat_interleave(agent_traj.shape[1],1) + Ne = ego_traj.shape[0] + if strategy=="maximum": + idx = mode_prob.argmax(dim=1) + idx = idx.reshape(Ne,*[1]*(agent_traj.ndim-1)) + agent_traj = agent_traj.take_along_dim(idx,1) + col_loss = get_collision_loss(ego_traj_tiled,agent_traj,ego_extents.tile(Ne,1),agent_extents.tile(Ne,1,1),agent_types.tile(Ne,1),col_funcs) + col_loss = col_loss.max(dim=1)[0] + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)).squeeze(0) + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + total_loss = weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[0],device=total_loss.device)*pert_std + if log_likelihood is not None: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + idx = total_loss.argmin() + return ego_traj[idx] + + +def obtain_ref(line, x, v, N, dt): + """obtain desired trajectory for the MPC controller + + Args: + line (np.ndarray): centerline of the lane [n, 3] + x (np.ndarray): position of the vehicle + v (np.ndarray): desired velocity + N (int): number of time steps + dt (float): time step + + Returns: + refx (np.ndarray): desired trajectory [N,3] + """ + line_length = line.shape[0] + delta_x = line[..., 0:2] - np.repeat(x[..., np.newaxis, 0:2], line_length, axis=-2) + dis = np.linalg.norm(delta_x, axis=-1) + idx = np.argmin(dis, axis=-1) + line_min = line[idx] + dx = x[0] - line_min[0] + dy = x[1] - line_min[1] + delta_y = -dx * np.sin(line_min[2]) + dy * np.cos(line_min[2]) + delta_x = dx * np.cos(line_min[2]) + dy * np.sin(line_min[2]) + refx0 = np.array( + [ + line_min[0] + delta_x * np.cos(line_min[2]), + line_min[1] + delta_x * np.sin(line_min[2]), + line_min[2], + ] + ) + s = [np.linalg.norm(line[idx + 1, 0:2] - refx0[0:2])] + for i in range(idx + 2, line_length): + s.append(s[-1] + np.linalg.norm(line[i, 0:2] - line[i - 1, 0:2])) + f = interp1d( + np.array(s), + line[idx + 1 :], + kind="linear", + axis=0, + copy=True, + bounds_error=False, + fill_value="extrapolate", + assume_sorted=True, + ) + s1 = v * np.arange(1, N + 1) * dt + refx = f(s1) + + return refx + +def Unicycle_braking_traj(x0,dt,T,brake_acc): + """generate braking trajectory for the unicycle model + + Args: + x0 (np.ndarray): initial state + dt (float): time step + T (int): total time step + brake_acc (float): deceleration + + Returns: + np.ndarray: braking trajectory + """ + if isinstance(x0,np.ndarray): + x = np.zeros((T+1,4)) + x[0] = x0 + x[:,3]=x0[3] + x[:,2] = np.clip(x0[2]+dt*np.arange(T+1)*brake_acc,a_min=0,a_max=np.inf) + x[1:,:2] = x[0:1,:2] + np.cumsum(x[:-1,2])[:,None]*dt*np.array([[np.cos(x0[3]),np.sin(x0[3])]]) + + elif isinstance(x0,torch.Tensor): + x = torch.zeros((T+1,4)) + x[0] = x0 + x[:,3]=x0[3] + x[:,2] = torch.clip(x0[2]+dt*torch.arange(T+1)*brake_acc,min=0,max=np.inf) + x[1:,:2] = x[0:1,:2] + torch.cumsum(x[:-1,2],0)[:,None]*dt*torch.tensor([[torch.cos(x0[3]),torch.sin(x0[3])]]) + return x[1:] + + +def test(): + x0=torch.tensor([4,5,7.0,0]) + dt=0.1 + T=30 + brake_acc=-5.0 + x=Unicycle_braking_traj(x0,dt,T,brake_acc) + print(x) + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/diffstack/utils/pred_utils.py b/diffstack/utils/pred_utils.py new file mode 100644 index 0000000..5a38a5d --- /dev/null +++ b/diffstack/utils/pred_utils.py @@ -0,0 +1,105 @@ +import os +import random +import torch +import time +import numpy as np + +from scipy.optimize import minimize +from collections import defaultdict +from typing import Dict, Union, Tuple, Any, Optional, Iterable +from trajdata.data_structures.batch import PadDirection, SceneBatch +from diffstack.utils.utils import batch_select + + +def compute_ade_pt(predicted_trajs, gt_traj): + error = torch.linalg.norm(predicted_trajs - gt_traj, dim=-1) + ade = torch.mean(error, axis=-1) + return ade.flatten() + + +def compute_ade(predicted_trajs, gt_traj): + error = np.linalg.norm(predicted_trajs - gt_traj, axis=-1) + ade = np.mean(error, axis=-1) + return ade.flatten() + + +def compute_fde_pt(predicted_trajs, gt_traj): + final_error = torch.linalg.norm(predicted_trajs[:, :, -1] - gt_traj[:, -1], dim=-1) + return final_error.flatten() + + +def compute_fde(predicted_trajs, gt_traj): + final_error = np.linalg.norm(predicted_trajs[:, :, -1] - gt_traj[-1], axis=-1) + return final_error.flatten() + + +def compute_nll_pt(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(gt_traj), min=-20.) + log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz[0], -log_p_y_xz_final[0] + + +def compute_nll(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(torch.as_tensor(gt_traj)), min=-20.) + log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz[0].numpy(), -log_p_y_xz_final[0].numpy() + + +def compute_prediction_metrics(prediction_output_dict, + futures, + y_dists=None, + keep_indices=None): + ade_errors = compute_ade_pt(prediction_output_dict, futures) + fde_errors = compute_fde_pt(prediction_output_dict, futures) + if y_dists: + nll_means, nll_finals = compute_nll_pt(y_dists, futures) + + if keep_indices is not None: + return {'ade': ade_errors[keep_indices], + 'fde': fde_errors[keep_indices], + 'nll_mean': nll_means[keep_indices], + 'nll_final': nll_finals[keep_indices]} + else: + return {'ade': ade_errors, + 'fde': fde_errors, + 'nll_mean': nll_means, + 'nll_final': nll_finals} + + +def split_predicted_agent_extents( + batch: SceneBatch, + num_dist_agents: int = 1, + num_single_agents: Optional[int] = None, + max_num_agents: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get extents from scene batch. + + Returns three tensors: + - ego_extents + - pred_agent_extents + - ml_agent_extents + """ + # TODO it should be the responsibility of the predictor to associate agent to its role. + # Here we just assume a fixed order + robot_ind = batch.extras["robot_ind"] + pred_ind = batch.extras["pred_agent_ind"] + assert (robot_ind == 0).all() and (pred_ind == 1).all(), "Agent roles are assumed to be hardcoded" + assert batch.history_pad_dir == PadDirection.AFTER + if num_single_agents is None: + if max_num_agents is None: + raise ValueError("Must specify either num_ml_agents or max_num_agents") + num_single_agents = max_num_agents + 1 - num_dist_agents + + agent_extent = batch_select(batch.agent_hist_extent, index=batch.agent_hist_len-1, batch_dims=2) # b, N, t, (length, width) + ego_extent = agent_extent[:, 0] + dist_extents = agent_extent[:, 1:1+num_dist_agents] + dist_extents = torch.nn.functional.pad(dist_extents, (0, 0, 0, num_dist_agents-dist_extents.shape[1]), 'constant', torch.nan) + + single_extents = agent_extent[:, 1+num_dist_agents:(1+num_dist_agents+num_single_agents)] + single_extents = torch.nn.functional.pad(single_extents, (0, 0, 0, num_single_agents-single_extents.shape[1]), 'constant', torch.nan) + # ml_extents = agent_extent[:, 2:(2+MAX_PLAN_NEIGHBORS+1)] + # ml_extents = torch.nn.functional.pad(ml_extents, (0, 0, 0, MAX_PLAN_NEIGHBORS+1-ml_extents.shape[1]), 'constant', torch.nan) + + return ego_extent, dist_extents, single_extents diff --git a/diffstack/utils/rollout_logger.py b/diffstack/utils/rollout_logger.py new file mode 100644 index 0000000..c2665b5 --- /dev/null +++ b/diffstack/utils/rollout_logger.py @@ -0,0 +1,226 @@ +from collections import defaultdict +import numpy as np +from copy import deepcopy + +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.policies.common import RolloutAction + +from torch.nn.utils.rnn import pad_sequence +class RolloutLogger(object): + """Log trajectories and other essential info during rollout for visualization and evaluation""" + def __init__(self, obs_keys=None): + if obs_keys is None: + obs_keys = dict() + self._obs_keys = obs_keys + self._scene_indices = None + self._agent_id_per_scene = dict() + self._agent_data_by_scene = dict() + self._scene_ts = defaultdict(lambda:0) + + def _combine_obs(self, obs): + combined = dict() + excluded_keys = ["extras"] + if "ego" in obs and obs["ego"] is not None: + combined.update(obs["ego"]) + if "agents" in obs and obs["agents"] is not None: + for k in obs["agents"].keys(): + if k in combined and k not in excluded_keys: + if obs["agents"][k] is not None: + if combined[k] is not None: + combined[k] = np.concatenate((combined[k], obs["agents"][k]), axis=0) + else: + combined[k] = obs["agents"][k] + else: + combined[k] = obs["agents"][k] + return combined + + def _combine_action(self, action: RolloutAction): + combined = dict(action=dict()) + if action.has_ego and not action.has_agents: + combined["action"] = action.ego.to_dict() + if action.ego_info is not None and "action_samples" in action.ego_info: + combined["action_samples"] = action.ego_info["action_samples"] + return combined + + elif action.has_agents and not action.has_ego: + combined["action"] = action.agents.to_dict() + if action.agents_info is not None and "action_samples" in action.agents_info: + combined["action_samples"] = action.agents_info["action_samples"] + return combined + elif action.has_agents and action.has_ego: + Nego = action.ego.positions.shape[0] + Nagents = action.agents.positions.shape[0] + combined["action"] = dict() + agents_action = action.agents.to_dict() + ego_action = action.ego.to_dict() + for k in agents_action: + if k in ego_action: + combined["action"][k] = np.concatenate((ego_action[k], agents_action[k]), axis=0) + if action.agents_info is not None and action.ego_info is not None: + if "action_samples" in action.ego_info: + ego_samples = action.ego_info["action_samples"] + else: + ego_samples = None + if "action_samples" in action.agents_info: + agents_samples = action.agents_info["action_samples"] + else: + agents_samples = None + if ego_samples is not None and agents_samples is None: + combined["action_samples"] = dict() + for k in ego_samples: + pad_k = np.zeros([Nagents,*ego_samples[k].shape[1:]]) + combined["action_samples"][k]=np.concatenate((ego_samples[k],pad_k),0) + elif ego_samples is None and agents_samples is not None: + combined["action_samples"] = dict() + for k in agents_samples: + pad_k = np.zeros([Nego,*agents_samples[k].shape[1:]]) + combined["action_samples"][k]=np.concatenate((pad_k,agents_samples[k]),0) + elif ego_samples is not None and agents_samples is not None: + combined["action_samples"] = dict() + for k in ego_samples: + if k in agents_samples: + if ego_samples[k].shape[1]>agents_samples[k].shape[1]: + pad_k = np.zeros([Nagents,ego_samples[k].shape[1]-agents_samples[k].shape[1],*agents_samples[k].shape[2:]]) + agents_samples[k]=np.concatenate((agents_samples[k],pad_k),1) + elif ego_samples[k].shape[1]0: + default_val = list(self._agent_data_by_scene[si][k][ti].values())[0] + ti_k = list() + for ts in range(self._scene_ts[si]): + ti_k.append(self._agent_data_by_scene[si][k][ti][ts] if ts in self._agent_data_by_scene[si][k][ti] else np.ones_like(default_val)*np.nan) + default_val = ti_k[-1] + if not all(elem.shape==ti_k[0].shape for elem in ti_k): + # requires padding + if np.issubdtype(ti_k[0].dtype,np.floating): + padding_value = np.nan + else: + padding_value = 0 + ti_k = [x[0] for x in ti_k] + ti_k_torch = TensorUtils.to_tensor(ti_k,ignore_if_unspecified=True) + + ti_k_padded = pad_sequence(ti_k_torch,padding_value=padding_value,batch_first=True) + serialized[si][k].append(TensorUtils.to_numpy(ti_k_padded)[np.newaxis,:]) + else: + if ti_k[0].ndim==0: + serialized[si][k].append(np.array(ti_k)[np.newaxis,:]) + else: + serialized[si][k].append(np.concatenate(ti_k,axis=0)[np.newaxis,:]) + else: + serialized[si][k].append(np.zeros_like(serialized[si][k][-1])) + if not all(elem.shape==serialized[si][k][0].shape for elem in serialized[si][k]): + # requires padding + if np.issubdtype(serialized[si][k][0][0].dtype,np.floating): + padding_value = np.nan + else: + padding_value = 0 + axes=[1,0]+np.arange(2,serialized[si][k][0].ndim-1).tolist() + mk_transpose = [np.transpose(x[0],axes) for x in serialized[si][k]] + mk_torch = TensorUtils.to_tensor(mk_transpose,ignore_if_unspecified=True) + mk_padded = pad_sequence(mk_torch,padding_value=padding_value) + mk = TensorUtils.to_numpy(mk_padded) + axes=[1,2,0]+np.arange(3,mk.ndim).tolist() + serialized[si][k]=np.transpose(mk,axes) + else: + serialized[si][k] = np.concatenate(serialized[si][k],axis=0) + + + + self._serialized_scene_buffer = serialized + return deepcopy(self._serialized_scene_buffer) + + def get_trajectory(self): + """Get per-scene rollout trajectory in the world coordinate system""" + buffer = self.get_serialized_scene_buffer() + traj = dict() + for si in buffer: + traj[si] = dict( + positions=buffer[si]["centroid"], + yaws=buffer[si]["world_yaw"] + ) + return traj + + def get_track_id(self): + return deepcopy(self._agent_id_per_scene) + + def get_stats(self): + # TODO + raise NotImplementedError() + + def log_step(self, obs, action: RolloutAction): + combined_obs = self._combine_obs(obs) + combined_action = self._combine_action(action) + assert combined_obs["scene_index"].shape[0] == combined_action["action"]["positions"].shape[0] + self._maybe_initialize(combined_obs) + self._append_buffer(combined_obs, combined_action) + for si in np.unique(combined_obs["scene_index"]): + self._scene_ts[si]+=1 + del combined_obs diff --git a/diffstack/utils/sys_utils.py b/diffstack/utils/sys_utils.py new file mode 100644 index 0000000..370fa6a --- /dev/null +++ b/diffstack/utils/sys_utils.py @@ -0,0 +1,12 @@ +import os + + +def delete_files_in_directory(directory_path): + try: + files = os.listdir(directory_path) + for file in files: + file_path = os.path.join(directory_path, file) + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + print("Error occurred while deleting files.") diff --git a/diffstack/utils/tensor_utils.py b/diffstack/utils/tensor_utils.py new file mode 100644 index 0000000..7ea4bff --- /dev/null +++ b/diffstack/utils/tensor_utils.py @@ -0,0 +1,1189 @@ +""" +A collection of utilities for working with nested tensor structures consisting +of numpy arrays and torch tensors. +""" +import collections +from tracemalloc import start +import numpy as np +import torch +import torch.nn as nn +import sys + + +def recursive_dict_list_tuple_apply(x, type_func_dict, ignore_if_unspecified=False): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert list not in type_func_dict + assert tuple not in type_func_dict + assert dict not in type_func_dict + assert nn.ParameterDict not in type_func_dict + assert nn.ParameterList not in type_func_dict + + if isinstance(x, (dict, collections.OrderedDict, nn.ParameterDict)): + new_x = ( + collections.OrderedDict() + if isinstance(x, collections.OrderedDict) + else dict() + ) + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply( + v, type_func_dict, ignore_if_unspecified + ) + return new_x + elif isinstance(x, (list, tuple, nn.ParameterList)): + ret = [ + recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) + for v in x + ] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + if ignore_if_unspecified: + return x + else: + raise NotImplementedError("Cannot handle data type %s" % str(type(x))) + + +def map_tensor(x, func): + """ + Apply function @func to torch.Tensor objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each tensor + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: func, + type(None): lambda x: x, + }, + ) + + +def map_ndarray(x, func): + """ + Apply function @func to np.ndarray objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + np.ndarray: func, + str: lambda x: x, + type(None): lambda x: x, + }, + ) + + +def map_tensor_ndarray(x, tensor_func, ndarray_func): + """ + Apply function @tensor_func to torch.Tensor objects and @ndarray_func to + np.ndarray objects in a nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + tensor_func (function): function to apply to each tensor + ndarray_Func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: tensor_func, + np.ndarray: ndarray_func, + type(None): lambda x: x, + }, + ) + + +def clone(x): + """ + Clones all torch tensors and numpy arrays in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.clone(), + np.ndarray: lambda x: x.copy(), + type(None): lambda x: x, + }, + ) + + +def detach(x): + """ + Detaches all torch tensors in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.detach(), + np.ndarray: lambda x: x, + type(None): lambda x: x, + }, + ignore_if_unspecified=True, + ) + + +def to_batch(x): + """ + Introduces a leading batch dimension of 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[None, ...], + np.ndarray: lambda x: x[None, ...], + type(None): lambda x: x, + }, + ) + + +def to_sequence(x): + """ + Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, None, ...], + np.ndarray: lambda x: x[:, None, ...], + type(None): lambda x: x, + }, + ) + + +def index_at_time(x, ind): + """ + Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in + nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ind (int): index + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, ind, ...], + np.ndarray: lambda x: x[:, ind, ...], + type(None): lambda x: x, + }, + ) + + +def unsqueeze(x, dim): + """ + Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.unsqueeze(dim=dim), + np.ndarray: lambda x: np.expand_dims(x, axis=dim), + type(None): lambda x: x, + }, + ) + + +def squeeze(x, dim): + """ + Reduce dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.squeeze(dim=dim), + np.ndarray: lambda x: np.squeeze(x, axis=dim), + type(None): lambda x: x, + }, + ) + + +def contiguous(x): + """ + Makes all torch tensors and numpy arrays contiguous in nested dictionary or + list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.contiguous(), + np.ndarray: lambda x: np.ascontiguousarray(x), + type(None): lambda x: x, + }, + ) + + +def to_device(x, device, ignore_if_unspecified=False): + """ + Sends all torch tensors in nested dictionary or list or tuple to device + @device, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, d=device: x.to(d), + str: lambda x: x, + type(None): lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_tensor(x, ignore_if_unspecified=False): + """ + Converts all numpy arrays in nested dictionary or list or tuple to + torch tensors (and leaves existing torch Tensors as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x, + str: lambda x: x, + np.ndarray: lambda x: torch.from_numpy(x), + type(None): lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_numpy(x, ignore_if_unspecified=False): + """ + Converts all torch tensors in nested dictionary or list or tuple to + numpy (and leaves existing numpy arrays as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy() + else: + return tensor.detach().numpy() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x, + type(None): lambda x: x, + str: lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_list(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to a list, and returns a new nested structure. Useful for + json encoding. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy().tolist() + else: + return tensor.detach().numpy().tolist() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x.tolist(), + type(None): lambda x: x, + }, + ) + + +def to_float(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to float type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.float(), + np.ndarray: lambda x: x.astype(np.float32), + type(None): lambda x: x, + }, + ) + + +def to_uint8(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to uint8 type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.byte(), + np.ndarray: lambda x: x.astype(np.uint8), + type(None): lambda x: x, + }, + ) + + +def to_torch(x, device, ignore_if_unspecified=False): + """ + Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to + torch tensors on device @device and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return to_device( + to_tensor(x, ignore_if_unspecified=ignore_if_unspecified), + device, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_one_hot_single(tensor, num_class): + """ + Convert tensor to one-hot representation, assuming a certain number of total class labels. + + Args: + tensor (torch.Tensor): tensor containing integer labels + num_class (int): number of classes + + Returns: + x (torch.Tensor): tensor containing one-hot representation of labels + """ + x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) + x.scatter_(-1, tensor.unsqueeze(-1), 1) + return x + + +def to_one_hot(tensor, num_class): + """ + Convert all tensors in nested dictionary or list or tuple to one-hot representation, + assuming a certain number of total class labels. + + Args: + tensor (dict or list or tuple): a possibly nested dictionary or list or tuple + num_class (int): number of classes + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) + + +def flatten_single(x, begin_axis=1): + """ + Flatten a tensor in all dimensions from @begin_axis onwards. + + Args: + x (torch.Tensor): tensor to flatten + begin_axis (int): which axis to flatten from + + Returns: + y (torch.Tensor): flattened tensor + """ + fixed_size = x.size()[:begin_axis] + _s = list(fixed_size) + [-1] + return x.reshape(*_s) + + +def flatten(x, begin_axis=1): + """ + Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): which axis to flatten from + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), + }, + ) + + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert begin_axis < end_axis + assert begin_axis >= 0 + assert end_axis <= len(x.shape) + assert isinstance(target_dims, (tuple, list)) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i >= end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def reshape_dimensions(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions for all tensors in nested dictionary or list or tuple + to a target dimension. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension (excluding the dimension) + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + type(None): lambda x: x, + }, + ) + + +def join_dimensions(x, begin_axis, end_axis): + """ + Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for + all tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension (excluding the dimension) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + type(None): lambda x: x, + }, + ) + + +def expand_at_single(x, size, dim): + """ + Expand a tensor at a single dimension @dim by @size + + Args: + x (torch.Tensor): input tensor + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (torch.Tensor): expanded tensor + """ + assert dim < x.ndimension() + assert x.shape[dim] == 1 + expand_dims = [-1] * x.ndimension() + expand_dims[dim] = size + return x.expand(*expand_dims) + + +def expand_at(x, size, dim): + """ + Expand all tensors in nested dictionary or list or tuple at a single + dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) + + +def unsqueeze_expand_at(x, size, dim): + """ + Unsqueeze and expand a tensor at a dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to unsqueeze and expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze(x, dim) + return expand_at(x, size, dim) + + +def repeat_by_expand_at(x, repeats, dim): + """ + Repeat a dimension by combining expand and reshape operations. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + repeats (int): number of times to repeat the target dimension + dim (int): dimension to repeat on + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze_expand_at(x, repeats, dim + 1) + return join_dimensions(x, dim, dim + 2) + + +def named_reduce_single(x, reduction, dim): + """ + Reduce tensor at a dimension by named reduction functions. + + Args: + x (torch.Tensor): tensor to be reduced + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (torch.Tensor): reduced tensor + """ + assert x.ndimension() > dim + assert reduction in ["sum", "max", "mean", "flatten"] + if reduction == "flatten": + x = flatten(x, begin_axis=dim) + elif reduction == "max": + x = torch.max(x, dim=dim)[0] # [B, D] + elif reduction == "sum": + x = torch.sum(x, dim=dim) + else: + x = torch.mean(x, dim=dim) + return x + + +def named_reduce(x, reduction, dim): + """ + Reduces all tensors in nested dictionary or list or tuple at a dimension + using a named reduction function. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d) + ) + + +def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): + """ + This function indexes out a target dimension of a tensor in a structured way, + by allowing a different value to be selected for each member of a flat index + tensor (@indices) corresponding to a source dimension. This can be interpreted + as moving along the source dimension, using the corresponding index value + in @indices to select values for all other dimensions outside of the + source and target dimensions. A common use case is to gather values + in target dimension 1 for each batch member (target dimension 0). + + Args: + x (torch.Tensor): tensor to gather values for + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out + """ + assert len(indices.shape) == 1 + assert x.shape[source_dim] == indices.shape[0] + + # unsqueeze in all dimensions except the source dimension + new_shape = [1] * x.ndimension() + new_shape[source_dim] = -1 + indices = indices.reshape(*new_shape) + + # repeat in all dimensions - but preserve shape of source dimension, + # and make sure target_dimension has singleton dimension + expand_shape = list(x.shape) + expand_shape[source_dim] = -1 + expand_shape[target_dim] = 1 + indices = indices.expand(*expand_shape) + + out = x.gather(dim=target_dim, index=indices) + return out.squeeze(target_dim) + + +def gather_from_start_single(tensor, indices): + if tensor.ndim < indices.ndim: + return tensor + gather_dim = indices.ndim - 1 + + desired_shape = list(indices.shape) + ([1] * (tensor.ndim - gather_dim - 1)) + repeats = [1] * indices.ndim + list(tensor.shape[indices.ndim :]) + indices = indices.reshape(desired_shape).repeat(repeats) + if isinstance(tensor, torch.Tensor): + return torch.gather(tensor, gather_dim, indices) + elif isinstance(tensor, np.ndarray): + return np.take(tensor, indices, axis=gather_dim) + + +def gather_from_start(tensor, indices): + return recursive_dict_list_tuple_apply( + tensor, + { + torch.Tensor: lambda x: gather_from_start_single(x, indices), + np.ndarray: lambda x: gather_from_start_single(x, indices), + type(None): lambda x: x, + }, + ) + + +def gather_along_dim_with_dim(x, target_dim, source_dim, indices): + """ + Apply @gather_along_dim_with_dim_single to all tensors in a nested + dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, + lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single( + y, t, s, i + ), + ) + + +def gather_sequence_single(seq, indices): + """ + Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in + the batch given an index for each sequence. + + Args: + seq (torch.Tensor): tensor with leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Return: + y (torch.Tensor): indexed tensor of shape [B, ....] + """ + return gather_along_dim_with_dim_single( + seq, target_dim=1, source_dim=0, indices=indices + ) + + +def gather_sequence(seq, indices): + """ + Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch + for tensors with leading dimensions [B, T, ...]. + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Returns: + y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] + """ + return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) + + +def slice_tensor_single(x, dim, start_idx, end_idx): + """select a slice of the tensor + + Args: + x (torch.Tensor or np.array): the tensor + dim (int): dimension to select + start_idx (int): starting index + end_idx (int): ending index + """ + assert start_idx >= 0 and start_idx <= end_idx and end_idx <= x.shape[dim] + if isinstance(x, np.ndarray): + return x.take(np.arange(start_idx, end_idx), dim) + elif isinstance(x, torch.Tensor): + return torch.index_select(x, dim, torch.arange(start_idx, end_idx).to(x.device)) + + +def slice_tensor(tensor, dim, start_idx, end_idx): + """recursively select a slice of the tensor or its field if tensor is a dict + + Args: + tensor (torch.Tensor or dict): the tensor + dim (int): dimension to select + start_idx (int): starting index + end_idx (int): ending index + """ + + return recursive_dict_list_tuple_apply( + tensor, + { + torch.Tensor: lambda x: slice_tensor_single(x, dim, start_idx, end_idx), + np.ndarray: lambda x: slice_tensor_single(x, dim, start_idx, end_idx), + type(None): lambda x: x, + }, + ) + + +def pad_sequence_single(seq, padding, batched=False, pad_same=False, pad_values=0.0): + """ + Pad input tensor or array @seq in the time dimension (dimension 1). + + Args: + seq (np.ndarray or torch.Tensor): sequence to be padded + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (np.ndarray or torch.Tensor) + """ + assert isinstance(seq, (np.ndarray, torch.Tensor)) + assert pad_same or (pad_values is not None) + if pad_values is not None: + assert isinstance(pad_values, float) + repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like + seq_dim = 1 if batched else 0 + + begin_pad = [] + end_pad = [] + + if padding[0] > 0: + if batched: + pad = seq[:, [0]] if pad_same else ones_like_func(seq[:, [0]]) * pad_values + else: + pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values + begin_pad.append(repeat_func(pad, padding[0], seq_dim)) + if padding[1] > 0: + if batched: + pad = ( + seq[:, [-1]] if pad_same else ones_like_func(seq[:, [-1]]) * pad_values + ) + else: + pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values + end_pad.append(repeat_func(pad, padding[1], seq_dim)) + + return concat_func(begin_pad + [seq] + end_pad, seq_dim) + + +def pad_sequence(seq, padding, batched=False, pad_same=False, pad_values=0.0): + """ + Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (dict or list or tuple) + """ + return recursive_dict_list_tuple_apply( + seq, + { + torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + type(None): lambda x: x, + }, + ) + + +def left_right_average(seq): + """Add 1 entry to the seq by averaging the left appended seq and right appended seq + + Args: + seq (np.ndarray or torch.Tensor): + """ + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + seq_left = concat_func((seq[:1], seq), 0) + seq_right = concat_func((seq, seq[-1:]), 0) + return 0.5 * (seq_left + seq_right) + + +def assert_size_at_dim_single(x, size, dim, msg): + """ + Ensure that array or tensor @x has size @size in dim @dim. + + Args: + x (np.ndarray or torch.Tensor): input array or tensor + size (int): size that tensors should have at @dim + dim (int): dimension to check + msg (str): text to display if assertion fails + """ + assert x.shape[dim] == size, msg + + +def assert_size_at_dim(x, size, dim, msg): + """ + Ensure that arrays and tensors in nested dictionary or list or tuple have + size @size in dim @dim. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size that tensors should have at @dim + dim (int): dimension to check + """ + map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) + + +def get_shape(x): + """ + Get all shapes of arrays and tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple that contains each array or + tensor's shape + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.shape, + np.ndarray: lambda x: x.shape, + type(None): lambda x: x, + }, + ) + + +def list_of_flat_dict_to_dict_of_list(list_of_dict): + """ + Helper function to go from a list of flat dictionaries to a dictionary of lists. + By "flat" we mean that none of the values are dictionaries, but are numpy arrays, + floats, etc. + + Args: + list_of_dict (list): list of flat dictionaries + + Returns: + dict_of_list (dict): dictionary of lists + """ + assert isinstance(list_of_dict, list) + dic = collections.OrderedDict() + for i in range(len(list_of_dict)): + for k in list_of_dict[i]: + if k not in dic: + dic[k] = [] + dic[k].append(list_of_dict[i][k]) + return dic + + +def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): + """ + Flatten a nested dict or list to a list. + + For example, given a dict + { + a: 1 + b: { + c: 2 + } + c: 3 + } + + the function would return [(a, 1), (b_c, 2), (c, 3)] + + Args: + d (dict, list): a nested dict or list to be flattened + parent_key (str): recursion helper + sep (str): separator for nesting keys + item_key (str): recursion helper + Returns: + list: a list of (key, value) tuples + """ + items = [] + if isinstance(d, (tuple, list)): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for i, v in enumerate(d): + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) + return items + elif isinstance(d, dict): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for k, v in d.items(): + assert isinstance(k, str) + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) + return items + else: + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + return [(new_key, d)] + + +def time_distributed( + inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs +): + """ + Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the + batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. + Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping + outputs to [B, T, ...]. + + Args: + inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + op: a layer op that accepts inputs + activation: activation to apply at the output + inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op + inputs_as_args (bool) whether to feed input as a args list to the op + kwargs (dict): other kwargs to supply to the op + + Returns: + outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. + """ + batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] + inputs = join_dimensions(inputs, 0, 2) + if inputs_as_kwargs: + outputs = op(**inputs, **kwargs) + elif inputs_as_args: + outputs = op(*inputs, **kwargs) + else: + outputs = op(inputs, **kwargs) + + if activation is not None: + outputs = map_tensor(outputs, activation) + outputs = reshape_dimensions( + outputs, begin_axis=0, end_axis=1, target_dims=(batch_size, seq_len) + ) + return outputs + + +def round_2pi(x): + return (x + np.pi) % (2 * np.pi) - np.pi + + +def cat_list_of_dict(x, dim): + """combining a list of dictionaries to a single dictionary by concatenating the values + + Args: + x (List[Dict[tensor.Torch]]): _description_ + """ + combined_dict = dict() + for k, v in x[0].items(): + if isinstance(v, torch.Tensor): + if v.ndim >= dim: + combined_dict[k] = torch.cat([xi[k] for xi in x], dim=dim) + + elif isinstance(v, dict): + combined_dict[k] = cat_list_of_dict([xi[k] for xi in x], dim) + return combined_dict + + +def block_diag_from_cat(x): + """convert a concatenated array to block diagonal + + Args: + x (Union[torch.Tensor,np.ndarray]): [B,M,n,n] + Returns: + mat: [B,Mxn,Mxn] + """ + if x.ndim == 3: + M = x.shape[1] + Id = torch.eye(M, device=x.device) + slices = [torch.kron(Id[i], x[:, i]) for i in range(M)] + mat = torch.cat(slices, 1) + return mat + elif x.ndim == 4: + bs, M, n, m = x.shape + Id = torch.eye(M, device=x.device) + slices = [ + torch.kron(Id[i], x[:, i].reshape(-1, m)).reshape(bs, n, m * M) + for i in range(M) + ] + mat = torch.cat(slices, 1) + return mat + + +def recursive_mean(x): + # x is a list of dict, potentially multi-level, with torch.Tensor as values, we want to compute the mean. + output = dict() + for k, v in x[0].items(): + if isinstance(v, dict): + output[k] = recursive_mean([xi[k] for xi in x]) + elif isinstance(v, torch.Tensor): + output[k] = torch.stack([xi[k] for xi in x], dim=0).mean(dim=0) + else: + raise ValueError("value must be torch.Tensor or dict") + return output + + +def flatten_dict(x): + res = dict() + for k, v in x.items(): + if isinstance(v, dict): + res.update(flatten_dict(v)) + else: + res[k] = v + return res diff --git a/diffstack/utils/timer.py b/diffstack/utils/timer.py new file mode 100644 index 0000000..c0d299e --- /dev/null +++ b/diffstack/utils/timer.py @@ -0,0 +1,65 @@ + +import time +import numpy as np +from contextlib import contextmanager + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + self.times = [] + + def recent_average_time(self, latest_n): + return np.mean(np.array(self.times)[-latest_n:]) + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.times.append(self.diff) + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + @contextmanager + def timed(self): + self.tic() + yield + self.toc() + + +class Timers(object): + def __init__(self): + self._timers = {} + + def tic(self, key): + if key not in self._timers: + self._timers[key] = Timer() + self._timers[key].tic() + + def toc(self, key): + self._timers[key].toc() + + @contextmanager + def timed(self, key): + self.tic(key) + yield + self.toc(key) + + def __str__(self): + msg = [] + for k, v in self._timers.items(): + msg.append('%s: %f' % (k, v.average_time)) + return ', '.join(msg) \ No newline at end of file diff --git a/diffstack/utils/torch_utils.py b/diffstack/utils/torch_utils.py new file mode 100644 index 0000000..677d5af --- /dev/null +++ b/diffstack/utils/torch_utils.py @@ -0,0 +1,316 @@ +""" +This file contains some PyTorch utilities. +""" +import numpy as np +import pytorch_lightning as pl +import torch +import torch.optim as optim +import functools +from tqdm.auto import tqdm +import time +from typing import Optional, Union + + +def soft_update(source, target, tau): + """ + Soft update from the parameters of a @source torch module to a @target torch module + with strength @tau. The update follows target = target * (1 - tau) + source * tau. + + Args: + source (torch.nn.Module): source network to push target network parameters towards + target (torch.nn.Module): target network to update + """ + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.copy_(target_param * (1.0 - tau) + param * tau) + + +def hard_update(source, target): + """ + Hard update @target parameters to match @source. + + Args: + source (torch.nn.Module): source network to provide parameters + target (torch.nn.Module): target network to update parameters for + """ + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.copy_(param) + + +def get_torch_device(try_to_use_cuda): + """ + Return torch device. If using cuda (GPU), will also set cudnn.benchmark to True + to optimize CNNs. + + Args: + try_to_use_cuda (bool): if True and cuda is available, will use GPU + + Returns: + device (torch.Device): device to use for models + """ + if try_to_use_cuda and torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + return device + + +def reparameterize(mu, logvar): + """ + Reparameterize for the backpropagation of z instead of q. + This makes it so that we can backpropagate through the sampling of z from + our encoder when feeding the sampled variable to the decoder. + + (See "The reparameterization trick" section of https://arxiv.org/abs/1312.6114) + + Args: + mu (torch.Tensor): batch of means from the encoder distribution + logvar (torch.Tensor): batch of log variances from the encoder distribution + + Returns: + z (torch.Tensor): batch of sampled latents from the encoder distribution that + support backpropagation + """ + # logvar = \log(\sigma^2) = 2 * \log(\sigma) + # \sigma = \exp(0.5 * logvar) + + # clamped for numerical stability + logstd = (0.5 * logvar).clamp(-4, 15) + std = torch.exp(logstd) + + # Sample \epsilon from normal distribution + # use std to create a new tensor, so we don't have to care + # about running on GPU or not + eps = std.new(std.size()).normal_() + + # Then multiply with the standard deviation and add the mean + z = eps.mul(std).add_(mu) + + return z + + +def optimizer_from_optim_params(net_optim_params, net): + """ + Helper function to return a torch Optimizer from the optim_params + section of the config for a particular network. + + Args: + optim_params (Config): optim_params part of algo_config corresponding + to @net. This determines the optimizer that is created. + + net (torch.nn.Module): module whose parameters this optimizer will be + responsible + + Returns: + optimizer (torch.optim.Optimizer): optimizer + """ + return optim.Adam( + params=net.parameters(), + lr=net_optim_params["learning_rate"]["initial"], + weight_decay=net_optim_params["regularization"]["L2"], + ) + + +def lr_scheduler_from_optim_params(net_optim_params, net, optimizer): + """ + Helper function to return a LRScheduler from the optim_params + section of the config for a particular network. Returns None + if a scheduler is not needed. + + Args: + optim_params (Config): optim_params part of algo_config corresponding + to @net. This determines whether a learning rate scheduler is created. + + net (torch.nn.Module): module whose parameters this optimizer will be + responsible + + optimizer (torch.optim.Optimizer): optimizer for this net + + Returns: + lr_scheduler (torch.optim.lr_scheduler or None): learning rate scheduler + """ + lr_scheduler = None + if len(net_optim_params["learning_rate"]["epoch_schedule"]) > 0: + # decay LR according to the epoch schedule + lr_scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, + milestones=net_optim_params["learning_rate"]["epoch_schedule"], + gamma=net_optim_params["learning_rate"]["decay_factor"], + ) + return lr_scheduler + + +def backprop_for_loss(net, optim, loss, max_grad_norm=None, retain_graph=False): + """ + Backpropagate loss and update parameters for network with + name @name. + + Args: + net (torch.nn.Module): network to update + + optim (torch.optim.Optimizer): optimizer to use + + loss (torch.Tensor): loss to use for backpropagation + + max_grad_norm (float): if provided, used to clip gradients + + retain_graph (bool): if True, graph is not freed after backward call + + Returns: + grad_norms (float): average gradient norms from backpropagation + """ + + # backprop + optim.zero_grad() + loss.backward(retain_graph=retain_graph) + + # gradient clipping + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm) + + # compute grad norms + grad_norms = 0.0 + for p in net.parameters(): + # only clip gradients for parameters for which requires_grad is True + if p.grad is not None: + grad_norms += p.grad.data.norm(2).pow(2).item() + + # step + optim.step() + + return grad_norms + + +class dummy_context_mgr: + """ + A dummy context manager - useful for having conditional scopes (such + as @maybe_no_grad). Nothing happens in this scope. + """ + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_value, traceback): + return False + + +def maybe_no_grad(no_grad): + """ + Args: + no_grad (bool): if True, the returned context will be torch.no_grad(), otherwise + it will be a dummy context + """ + return torch.no_grad() if no_grad else dummy_context_mgr() + + +def rgetattr(obj, attr, *args): + "recursively get attributes" + + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def rsetattr(obj, attr, val): + "recursively set attributes" + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +class ProgressBar(pl.Callback): + def __init__( + self, global_progress: bool = True, leave_global_progress: bool = True + ): + super().__init__() + + self.global_progress = global_progress + self.global_desc = "Epoch: {epoch}/{max_epoch}" + self.leave_global_progress = leave_global_progress + self.global_pb = None + + def on_fit_start(self, trainer, pl_module): + desc = self.global_desc.format( + epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs + ) + + self.global_pb = tqdm( + desc=desc, + total=trainer.max_epochs, + initial=trainer.current_epoch, + leave=self.leave_global_progress, + disable=not self.global_progress, + ) + + def on_fit_end(self, trainer, pl_module): + self.global_pb.close() + self.global_pb = None + + def on_epoch_end(self, trainer, pl_module): + + # Set description + desc = self.global_desc.format( + epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs + ) + self.global_pb.set_description(desc) + + # Set logs and metrics + # logs = pl_module.logs + # for k, v in logs.items(): + # if isinstance(v, torch.Tensor): + # logs[k] = v.squeeze().item() + # self.global_pb.set_postfix(logs) + + # Update progress + self.global_pb.update(1) + +def tic(timer: bool = True) -> Union[None, float]: + """Use to compute time for time-consuming process, call it before .toc()""" + + start_time = None + if timer: + torch.cuda.synchronize() + start_time: float = time.time() + + return start_time + + +def toc( + start_time: float, name: str = "", timer: bool = True, log=None +) -> Optional[float]: + """Use to compute time for time-consuming process, call it after .tic()""" + + if timer: + torch.cuda.synchronize() + end_time: float = time.time() + elapsed_ms: float = (end_time - start_time) * 1000 + print_str: str = f"{name:30s} EP: {elapsed_ms:.2f} ms" + + if log is not None: + print_log(print_str, log=log, display=False) + else: + print(print_str) + + return elapsed_ms + + return None + + +def print_log(print_str, log, same_line=False, display=True): + ''' + print a string to a log file + + parameters: + print_str: a string to print + log: a opened file to save the log + same_line: True if we want to print the string without a new next line + display: False if we want to disable to print the string onto the terminal + ''' + if display: + if same_line: print('{}'.format(print_str), end='') + else: print('{}'.format(print_str)) + + if same_line: log.write('{}'.format(print_str)) + else: log.write('{}\n'.format(print_str)) + log.flush() + \ No newline at end of file diff --git a/diffstack/utils/tpp_utils.py b/diffstack/utils/tpp_utils.py new file mode 100644 index 0000000..4744c33 --- /dev/null +++ b/diffstack/utils/tpp_utils.py @@ -0,0 +1,994 @@ +from collections import defaultdict +import numpy as np +from scipy.interpolate import interp1d +import torch + +# from diffstack.models.cnn_roi_encoder import rasterized_ROI_align +TRAJ_INDEX = [0, 1, 4] +STATE_INDEX = [0, 1, 4, 2] +INPUT_INDEX = [3, 5] +from Pplan.Sampling.tree import Tree +from Pplan.Sampling.trajectory_tree import TrajTree + +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.planning_utils import get_drivable_area_loss +from diffstack.modules.cost_functions.tpp_internal_costs import TPPInternalCost + + +class AgentTrajTree(Tree): + def __init__(self, traj, parent, depth, prob=None): + self.traj = traj + self.children = list() + self.parent = parent + if parent is not None: + parent.expand(self) + self.depth = depth + self.prob = prob + self.attribute = dict() + + +# The state in Pplan contains more higher order derivatives, TRAJ_INDEX selects x,y, and heading +# out of the longer state vector + + +def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types +): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B, N, T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + # veh_mask = (raw_types >= 3) & (raw_types <= 13) + # ped_mask = (raw_types == 14) | (raw_types == 15) + veh_mask = raw_types == 1 + ped_mask = raw_types == 2 + + edges = torch.zeros([B, N, A, T, 10]).to(ego_trajectories.device) + edges[..., :3] = ego_trajectories.unsqueeze(2).repeat(1, 1, A, 1, 1) + if agent_trajectories.ndim == 4: + edges[..., 3:6] = agent_trajectories.unsqueeze(1).repeat(1, N, 1, 1, 1) + else: + edges[..., 3:6] = agent_trajectories + edges[..., 6:8] = ego_extents.reshape(B, 1, 1, 1, 2).repeat(1, N, A, T, 1) + edges[..., 8:] = agent_extents.reshape(B, 1, A, 1, 2).repeat(1, N, 1, T, 1) + type_mask = {"VV": veh_mask, "VP": ped_mask} + return edges, type_mask + + +def get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + prob=None, + col_funcs=None, +): + """Get veh-veh and veh-ped collision loss.""" + # with torch.no_grad(): + ego_edges, type_mask = gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ) + if col_funcs is None: + col_funcs = { + "VV": GeoUtils.VEH_VEH_collision, + "VP": GeoUtils.VEH_PED_collision, + } + B, N, T = ego_trajectories.shape[:3] + col_loss = torch.zeros([B, N]).to(ego_trajectories.device) + for et, func in col_funcs.items(): + dis = func( + ego_edges[..., 0:3], + ego_edges[..., 3:6], + ego_edges[..., 6:8], + ego_edges[..., 8:], + ) + if dis.nelement() > 0: + col_loss += ( + torch.sum( + torch.sigmoid(-dis * 4) * type_mask[et][:, None, :, None], + dim=[2, 3], + ) + / T + ) + + return col_loss + + +# def get_drivable_area_loss( +# ego_trajectories, raster_from_agent, dis_map, ego_extents +# ): +# """Cost for road departure.""" +# with torch.no_grad(): + +# lane_flags = rasterized_ROI_align( +# dis_map, +# ego_trajectories[..., :2], +# ego_trajectories[..., 2:], +# raster_from_agent, +# torch.ones(*ego_trajectories.shape[:3] +# ).to(ego_trajectories.device), +# ego_extents.unsqueeze(1).repeat(1, ego_trajectories.shape[1], 1), +# 1, +# ).squeeze(-1) +# return lane_flags.max(dim=-1)[0] + + +def get_lane_loss_simple(ego_trajectories, raster_from_agent, dis_map): + h, w = dis_map.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points( + ego_trajectories[..., :2], raster_from_agent + ) + raster_xy[..., 0] = raster_xy[..., 0].clip(0, w - 1e-5) + raster_xy[..., 1] = raster_xy[..., 1].clip(0, h - 1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = raster_xy[..., 1] * w + raster_xy[..., 0] + raster_xy_flat = raster_xy_flat.flatten() + lane_loss = (dis_map.flatten()[raster_xy_flat]).reshape(*raster_xy.shape[:2]) + return lane_loss.max(dim=-1)[0] + + +def get_lane_loss_vectorized(ego_trajectories, lane_info, ego_extents): + Ne, T = ego_trajectories.shape[:2] + cost = torch.zeros(Ne, device=ego_trajectories.device) + + if "leftbdry" in lane_info: + delta_x, delta_y, _ = GeoUtils.batch_proj( + ego_trajectories.reshape(-1, 3), + TensorUtils.to_torch(lane_info["leftbdry"], device=ego_extents.device)[ + None + ].repeat_interleave(Ne * T, 0), + ) + idx = delta_x.abs().argmin(1) + leftmargin = ( + -delta_y.gather(1, idx.reshape(-1, 1)).reshape(Ne, T) - ego_extents[1] / 2 + ) + cost += -(leftmargin.min(1)[0]).clamp(max=0) + if "rightbdry" in lane_info: + delta_x, delta_y, _ = GeoUtils.batch_proj( + ego_trajectories.reshape(-1, 3), + TensorUtils.to_torch(lane_info["rightbdry"], device=ego_extents.device)[ + None + ].repeat_interleave(Ne * T, 0), + ) + idx = delta_x.abs().argmin(1) + rightmargin = ( + delta_y.gather(1, idx.reshape(-1, 1)).reshape(Ne, T) - ego_extents[1] / 2 + ) + cost += -(rightmargin.min(1)[0]).clamp(max=0) + return cost + + +def get_terminal_likelihood_reward(ego_trajectories, raster_from_agent, log_likelihood): + """Cost for road departure.""" + + log_likelihood = (log_likelihood - log_likelihood.mean()) / log_likelihood.std() + h, w = log_likelihood.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points( + ego_trajectories[..., -1, :2], raster_from_agent + ) + raster_xy[..., 0] = raster_xy[..., 0].clip(0, w - 1e-5) + raster_xy[..., 1] = raster_xy[..., 1].clip(0, h - 1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = raster_xy[..., 1] * w + raster_xy[..., 0] + + ll_reward = log_likelihood.flatten()[raster_xy_flat] + return ll_reward + + +def get_progress_reward(ego_trajectories, d_sat=10): + dis = torch.linalg.norm( + ego_trajectories[..., -1, :2] - ego_trajectories[..., 0, :2], dim=-1 + ) + return 2 / np.pi * torch.atan(dis / d_sat) + + +def get_total_distance(ego_trajectories): + """Reward that incentivizes progress.""" + # Assume format [..., T, 3] + assert ego_trajectories.shape[-1] == 3 + diff = ego_trajectories[..., 1:, :] - ego_trajectories[..., :-1, :] + dist = torch.norm(diff[..., :2], dim=-1) + total_dist = torch.sum(dist, dim=-1) + return total_dist + + +def ego_sample_planning( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + raster_from_agent, + dis_map, + weights, + log_likelihood=None, + col_funcs=None, +): + """A basic cost function for prediction-and-planning""" + col_loss = get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + col_funcs, + ) + # lane_loss = get_drivable_area_loss( + # ego_trajectories, raster_from_agent, dis_map, ego_extents + # ) + progress = get_total_distance(ego_trajectories) + + log_likelihood = 0 if log_likelihood is None else log_likelihood + if log_likelihood.ndim == 3: + log_likelihood = get_terminal_likelihood_reward( + ego_trajectories, raster_from_agent, log_likelihood + ) + + total_score = ( + +weights["likelihood_weight"] * log_likelihood + + weights["progress_weight"] * progress + - weights["collision_weight"] * col_loss + # - weights["lane_weight"] * lane_loss + ) + + return torch.argmax(total_score, dim=1) + + +class TreeMotionPolicy(object): + """A trajectory tree policy as the result of contingency planning""" + + def __init__( + self, + stage, + num_frames_per_stage, + ego_root, + scenario_root, + cost_to_go, + leaf_idx, + curr_node, + ): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_root = ego_root + self.scenario_root = scenario_root + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node = curr_node + + def identify_branch(self, ego_node, scene_traj): + assert scene_traj.shape[-2] < self.stage * self.num_frames_per_stage + assert ego_node.total_traj.shape[0] - 1 >= scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + ego_leaf_index = self.leaf_idx[ego_node] + while remain_traj.shape[1] > 0: + seg_length = min(remain_traj.shape[-2], self.num_frames_per_stage) + dis = [ + torch.linalg.norm( + child.traj[ego_leaf_index, :, :seg_length, :2] + - remain_traj[:, :seg_length, :2], + dim=-1, + ) + .sum() + .item() + for child in curr_scenario_node.children + ] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[..., seg_length:, :] + remain_num_frames = curr_scenario_node.traj.shape[-2] - seg_length + if curr_scenario_node.stage >= self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self, scene_traj, horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0] - 1 - T + assert remain_num_frames > -self.num_frames_per_stage + if remain_num_frames <= 0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node, scene_traj) + Q = [ + self.cost_to_go[(child, curr_scenario_node)] + for child in self.curr_node.children + ] + idx = torch.argmin(torch.tensor(Q)).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames += self.curr_node.traj.shape[0] + state = self.curr_node.traj[-remain_num_frames:, STATE_INDEX] + action = self.curr_node.traj[-remain_num_frames:, INPUT_INDEX] + if not self.curr_node.isleaf(): + state = torch.cat( + (state, self.curr_node.children[0].traj[:, STATE_INDEX]), -2 + ) + action = torch.cat( + (action, self.curr_node.children[0].traj[:, INPUT_INDEX]), -2 + ) + + if state.shape[0] >= horizon: + return state[:horizon], action[:horizon] + else: + state_patched = torch.cat( + (state, state[-1].tile(horizon - state.shape[0], 1)) + ) + action_patched = torch.cat( + ( + action, + torch.zeros_like(action[-1]).tile(horizon - action.shape[0], 1), + ) + ) + return state_patched, action_patched + + +class VectorizedTreeMotionPolicy(TreeMotionPolicy): + """A vectorized trajectory tree policy as the result of contingency planning""" + + def __init__( + self, + stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + cost_to_go, + leaf_idx, + curr_node, + ): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_tree = ego_tree + self.ego_root = ego_tree[0][0] + self.children_indices = children_indices + self.scenario_tree = scenario_tree + self.scenario_root = scenario_tree[0][0] + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node = curr_node + + def identify_branch(self, ego_node, scene_traj): + assert scene_traj.shape[-2] < self.stage * self.num_frames_per_stage + assert ego_node.total_traj.shape[0] - 1 >= scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + + stage = ego_node.depth + ego_stage_index = self.ego_tree[stage].index(ego_node) + ego_leaf_index = self.leaf_idx[stage][ego_stage_index].item() + while remain_traj.shape[1] > 0: + seg_length = min(remain_traj.shape[-2], self.num_frames_per_stage) + dis = [ + torch.linalg.norm( + child.traj[ego_leaf_index, :, :seg_length, :2] + - remain_traj[:, :seg_length, :2], + dim=-1, + ) + .sum() + .item() + for child in curr_scenario_node.children + ] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[..., seg_length:, :] + remain_num_frames = curr_scenario_node.traj.shape[-2] - seg_length + if curr_scenario_node.stage >= self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self, scene_traj, horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0] - 1 - T + assert remain_num_frames > -self.num_frames_per_stage + if remain_num_frames <= 0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node, scene_traj) + assert curr_scenario_node.depth == self.curr_node.depth + stage = self.curr_node.depth + scene_node_idx = self.scenario_tree[stage].index(curr_scenario_node) + curr_node_idx = self.ego_tree[stage].index(self.curr_node) + Q = self.cost_to_go[stage][ + scene_node_idx, self.children_indices[stage][curr_node_idx] + ] + idx = torch.argmin(Q).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames += self.curr_node.traj.shape[0] + + state = self.curr_node.traj[-remain_num_frames:, STATE_INDEX] + action = self.curr_node.traj[-remain_num_frames:, INPUT_INDEX] + if not self.curr_node.isleaf(): + state = torch.cat( + (state, self.curr_node.children[0].traj[:, STATE_INDEX]), -2 + ) + action = torch.cat( + (action, self.curr_node.children[0].traj[:, INPUT_INDEX]), -2 + ) + if state.shape[0] >= horizon: + return state[:horizon], action[:horizon] + else: + state_patched = torch.cat( + (state, state[-1].tile(horizon - state.shape[0], 1)) + ) + action_patched = torch.cat( + ( + action, + torch.zeros_like(action[-1]).tile(horizon - action.shape[0], 1), + ) + ) + return state_patched, action_patched + + def get_traj_array(self): + xu_batch = list() + root_xu = torch.cat( + (self.ego_root.traj[:, STATE_INDEX], self.ego_root.traj[:, INPUT_INDEX]), -1 + ) + for branch in self.ego_tree[1]: + xu = [root_xu] + while True: + xu.append( + torch.cat( + (branch.traj[:, STATE_INDEX], branch.traj[:, INPUT_INDEX]), -1 + ) + ) + if branch.isleaf(): + break + else: + branch = branch.children[0] + xu = torch.cat(xu, 0) + xu_batch.append(xu) + return torch.stack(xu_batch, 0) + + +def tiled_to_tree(total_traj, prob, num_stage, num_frames_per_stage, M): + """Turning a trajectory tree in tiled form to a tree data structure + + Args: + total_traj (torch.tensor or np.ndarray): tiled trajectory tree + prob (torch.tensor or np.ndarray): probability of the modes + num_stage (int): number of layers of the tree + num_frames_per_stage (int): number of time frames per layer + M (int): branching factor + + Returns: + nodes (dict[int:List(AgentTrajTree)]): all branches of the trajectory tree nodes indexed by layer + """ + + # total_traj = TensorUtils.reshape_dimensions_single(total_traj,2,3,[M]*num_stage) + x0 = AgentTrajTree(None, None, 0) + nodes = defaultdict(lambda: list()) + nodes[0].append(x0) + for t in range(num_stage): + interval = M ** (num_stage - t - 1) + tiled_traj = total_traj[ + ..., + ::interval, + :, + t * num_frames_per_stage : (t + 1) * num_frames_per_stage, + :, + ] + for i in range(M ** (t + 1)): + parent_idx = int(i / M) + p = prob[:, i * interval : (i + 1) * interval].sum(-1) + node = AgentTrajTree(tiled_traj[:, i], nodes[t][parent_idx], t + 1, prob=p) + nodes[t + 1].append(node) + return nodes + + +def contingency_planning( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std=None, +): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + device = agent_traj.device + leaf_idx = defaultdict(lambda: list()) + for stage in range(num_stage, -1, -1): + for node in ego_tree[stage]: + if node.isleaf(): + leaf_idx[node] = [ego_tree[stage].index(node)] + else: + leaf_idx[node] = [] + for child in node.children: + leaf_idx[node] = leaf_idx[node] + leaf_idx[child] + + V = dict() + L = dict() + Q = dict() + scenario_tree = tiled_to_tree( + agent_traj, mode_prob, num_stage, num_frames_per_stage, M + ) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + for stage in range(num_stage, 0, -1): + if stage == 0: + total_loss = torch.zeros([1, 1], device=device) + else: + ego_nodes = ego_tree[stage] + indices = [leaf_idx[node][0] for node in ego_nodes] + ego_traj = [node.traj[:, TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj, 0) + agent_nodes = scenario_tree[stage] + agent_traj = [node.traj[indices] for node in agent_nodes] + agent_traj = torch.stack(agent_traj, 0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes), 1, 1, 1) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes), 1), + agent_extents.tile(len(agent_nodes), 1, 1), + agent_types.tile(len(agent_nodes), 1), + col_funcs, + ) + + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + + total_loss = weights["collision_weight"] * col_loss - weights[ + "progress_weight" + ] * progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[1], device=device).unsqueeze(0) + * pert_std + ) + if log_likelihood is not None and stage == num_stage: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + + for i in range(len(ego_nodes)): + for j in range(len(agent_nodes)): + L[(ego_nodes[i], agent_nodes[j])] = total_loss[j, i] + if stage == num_stage: + V[(ego_nodes[i], agent_nodes[j])] = float(total_loss[j, i]) + else: + children_cost_to_go = [ + Q[(child, agent_nodes[j])] for child in ego_nodes[i].children + ] + V[(ego_nodes[i], agent_nodes[j])] = float(total_loss[j, i]) + min( + children_cost_to_go + ) + + if stage > 0: + for agent_node in scenario_tree[stage - 1]: + cost_i = [] + prob_i = [] + for child in agent_node.children: + cost_i.append(V[ego_nodes[i], child]) + prob_i.append(child.prob[leaf_idx[ego_nodes[i]]].sum()) + cost_i = torch.tensor(cost_i, device=device) + prob_i = torch.stack(prob_i) + Q[(ego_nodes[i], agent_node)] = float( + (cost_i * prob_i).sum() / prob_i.sum() + ) + Q_root = [Q[(child, scenario_root)] for child in ego_root.children] + idx = torch.argmin(torch.tensor(Q_root)).item() + optimal_node = ego_root.children[idx] + motion_policy = TreeMotionPolicy( + num_stage, + num_frames_per_stage, + ego_root, + scenario_root, + Q, + leaf_idx, + optimal_node, + ) + motion_policy.get_plan(None, num_stage * num_frames_per_stage) + return motion_policy + + +def get_cost_for_trajs(xu_batch, agent_traj, cost_obj, goal, lanes): + bs = len(xu_batch) + numMode = agent_traj.shape[0] + if agent_traj.nelement() == 0: + dummy_shape = list(agent_traj.shape) + dummy_shape[2] = 1 + agent_traj = torch.ones(dummy_shape, device=agent_traj.device) * 1e3 + + # Tile batch for each prediction mode, and input multi-modal predictions as pred_singles + ego_xu_tiled = xu_batch.repeat_interleave(numMode, 0) + pred_singles = TensorUtils.join_dimensions(agent_traj[..., :2], 0, 2) + pred_mus = torch.zeros( + (bs * numMode, 0, pred_singles.shape[-2], 1, 2), device=xu_batch.device + ) # b, N, T, K, 2 + pred_probs = torch.zeros( + [bs * numMode, 1, pred_singles.shape[-2]], device=xu_batch.device + ) + goal = ( + goal.unsqueeze(0).repeat_interleave(numMode * bs, 0) + if goal is not None + else None + ) + lanes = ( + lanes.unsqueeze(0).repeat_interleave(numMode * bs, 0) + if lanes is not None + else None + ) + + cost_inputs = (pred_mus, pred_probs, pred_singles, goal, lanes) + cost_inputs = TensorUtils.to_device(cost_inputs, ego_xu_tiled.device) + traj_cost = cost_obj(ego_xu_tiled, cost_inputs) # b, T + + # sum over time + traj_cost = traj_cost.sum(1) + # recover batch and prediction modes + traj_cost = traj_cost.reshape(bs, numMode) + + return traj_cost + + +def contingency_planning_parallel( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + lane_info, + weights, + num_frames_per_stage, + M, + dt, + cost_obj=None, + lanes=None, + goal=None, + col_funcs=None, + log_likelihood=None, + lane_type="rasterized", + pert_std=None, +): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + device = agent_traj.device + children_indices = TrajTree.get_children_index_torch(ego_tree) + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + + leaf_idx = {num_stage: torch.arange(len(ego_tree[num_stage]), device=device)} + stage_prob = {num_stage: mode_prob.T} + for stage in range(num_stage - 1, -1, -1): + leaf_idx[stage] = leaf_idx[stage + 1][children_indices[stage][:, 0]] + prob_next = stage_prob[stage + 1] + stage_prob[stage] = prob_next.reshape(-1, M, prob_next.shape[-1])[ + :, :, children_indices[stage][:, 0] + ].sum(1) + + V = dict() + L = dict() + Q = dict() + + scenario_tree = tiled_to_tree( + agent_traj, mode_prob, num_stage, num_frames_per_stage, M + ) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + + for stage in range(num_stage, -1, -1): + if stage == 0: + total_loss = torch.zeros([1, 1], device=device) + else: + # calculate stage cost + ego_nodes = ego_tree[stage] + ego_traj = [node.traj[:, TRAJ_INDEX] for node in ego_nodes] + + ego_traj = torch.stack(ego_traj, 0) + agent_nodes = scenario_tree[stage] + + agent_traj = [node.traj[leaf_idx[stage]] for node in agent_nodes] + + agent_traj = torch.stack(agent_traj, 0) + if cost_obj is None or isinstance(cost_obj, TPPInternalCost): + # use internal TPP cost + if agent_traj.nelement() == 0: + col_loss = torch.zeros( + [*agent_traj.shape[:2]], device=ego_traj.device + ) + else: + ego_traj_tiled = ego_traj.unsqueeze(0).repeat( + len(agent_nodes), 1, 1, 1 + ) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes), 1), + agent_extents.tile(len(agent_nodes), 1, 1), + agent_types.tile(len(agent_nodes), 1), + col_funcs, + ) + + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + if lane_type == "rasterized": + lane_loss = get_lane_loss_simple( + ego_traj, raster_from_agent, lane_info + ).unsqueeze(0) + elif lane_type == "vectorized": + lane_loss = get_lane_loss_vectorized( + ego_traj, lane_info, ego_extents + ).unsqueeze(0) + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + + total_loss = ( + weights["collision_weight"] * col_loss + - weights["progress_weight"] * progress_reward.unsqueeze(0) + + weights["lane_weight"] * lane_loss + ) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[1], device=device).unsqueeze(0) + * pert_std + ) + if log_likelihood is not None and stage == num_stage: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + else: + # use diffstack cost + ego_x = torch.stack( + [node.traj[:, STATE_INDEX] for node in ego_nodes], 0 + ) + ego_u = torch.stack( + [node.traj[:, INPUT_INDEX] for node in ego_nodes], 0 + ) + ego_x_pre = torch.stack( + [node.parent.traj[-1:, STATE_INDEX] for node in ego_nodes], 0 + ) + ego_u_pre = torch.stack( + [node.parent.traj[-1:, INPUT_INDEX] for node in ego_nodes], 0 + ) + ego_xu = torch.cat((ego_x, ego_u), -1) + ego_xu_pre = torch.cat((ego_x_pre, ego_u_pre), -1) + ego_xu = torch.cat((ego_xu_pre, ego_xu), 1) + lane_seg = lanes[ + (stage - 1) * num_frames_per_stage : stage * num_frames_per_stage + + 1 + ] + + total_loss = get_cost_for_trajs( + ego_xu, agent_traj, cost_obj, goal, lane_seg + ).transpose(0, 1) + + L[stage] = total_loss + if stage == num_stage: + V[stage] = total_loss + else: + children_idx = children_indices[stage] + # add the last Q value as inf since empty children index are padded with -1 + Q_prime = torch.cat( + (Q[stage], torch.full([Q[stage].shape[0], 1], np.inf, device=device)), 1 + ) + Q_by_node = Q_prime[:, children_idx] + V[stage] = total_loss + Q_by_node.min(dim=-1)[0] + + if stage > 0: + children_V = V[stage] + children_V = children_V.reshape(-1, M, children_V.shape[-1]) + prob = stage_prob[stage] + prob = prob.reshape(-1, M, prob.shape[-1]) + prob_normalized = prob / prob.sum(dim=1, keepdim=True) + Q[stage - 1] = (children_V * prob_normalized).sum(dim=1) + + idx = Q[0].argmin().item() + + motion_policy = VectorizedTreeMotionPolicy( + num_stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + Q, + leaf_idx, + ego_root.children[idx], + ) + return motion_policy + + +def one_shot_planning( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std=None, + strategy="all", +): + """Alternative of contingency planning, try to avoid all predicted trajectories + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + assert strategy == "all" or strategy == "maximum" + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + + ego_traj = [node.total_traj[1:, TRAJ_INDEX] for node in ego_tree[num_stage]] + ego_traj = torch.stack(ego_traj, 0) + ego_traj_tiled = ego_traj.unsqueeze(1).repeat_interleave(agent_traj.shape[1], 1) + Ne = ego_traj.shape[0] + if strategy == "maximum": + idx = mode_prob.argmax(dim=1) + idx = idx.reshape(Ne, *[1] * (agent_traj.ndim - 1)) + agent_traj = agent_traj.take_along_dim(idx, 1) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(Ne, 1), + agent_extents.tile(Ne, 1, 1), + agent_types.tile(Ne, 1), + col_funcs, + ) + col_loss = col_loss.max(dim=1)[0] + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)).squeeze(0) + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + total_loss = ( + weights["collision_weight"] * col_loss + - weights["progress_weight"] * progress_reward + ) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[0], device=total_loss.device) * pert_std + ) + if log_likelihood is not None: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + + idx = total_loss.argmin() + return ego_traj[idx] + + +def obtain_ref(line, x, v, N, dt): + """obtain desired trajectory for the MPC controller + + Args: + line (np.ndarray): centerline of the lane [n, 3] + x (np.ndarray): position of the vehicle + v (np.ndarray): desired velocity + N (int): number of time steps + dt (float): time step + + Returns: + refx (np.ndarray): desired trajectory [N,3] + """ + line_length = line.shape[0] + delta_x = line[..., 0:2] - np.repeat(x[..., np.newaxis, 0:2], line_length, axis=-2) + dis = np.linalg.norm(delta_x, axis=-1) + idx = np.argmin(dis, axis=-1) + line_min = line[idx] + dx = x[0] - line_min[0] + dy = x[1] - line_min[1] + delta_y = -dx * np.sin(line_min[2]) + dy * np.cos(line_min[2]) + delta_x = dx * np.cos(line_min[2]) + dy * np.sin(line_min[2]) + refx0 = np.array( + [ + line_min[0] + delta_x * np.cos(line_min[2]), + line_min[1] + delta_x * np.sin(line_min[2]), + line_min[2], + ] + ) + s = [np.linalg.norm(line[idx + 1, 0:2] - refx0[0:2])] + for i in range(idx + 2, line_length): + s.append(s[-1] + np.linalg.norm(line[i, 0:2] - line[i - 1, 0:2])) + f = interp1d( + np.array(s), + line[idx + 1 :], + kind="linear", + axis=0, + copy=True, + bounds_error=False, + fill_value="extrapolate", + assume_sorted=True, + ) + s1 = v * np.arange(1, N + 1) * dt + refx = f(s1) + + return refx diff --git a/diffstack/utils/train_utils.py b/diffstack/utils/train_utils.py new file mode 100644 index 0000000..9349d78 --- /dev/null +++ b/diffstack/utils/train_utils.py @@ -0,0 +1,151 @@ +""" +This file contains several utility functions used to define the main training loop. It +mainly consists of functions to assist with logging, rollouts, and the @run_epoch function, +which is the core training logic for models in this repository. +""" +import os +import socket +import shutil +import pytorch_lightning as pl +from pytorch_lightning.loops.utilities import _reset_progress + +def infinite_iter(data_loader): + """ + Get an infinite generator + Args: + data_loader (DataLoader): data loader to iterate through + + """ + c_iter = iter(data_loader) + while True: + try: + data = next(c_iter) + except StopIteration: + c_iter = iter(data_loader) + data = next(c_iter) + yield data + + +def get_exp_dir(exp_name, output_dir, save_checkpoints=True, auto_remove_exp_dir=False): + """ + Create experiment directory from config. If an identical experiment directory + exists and @auto_remove_exp_dir is False (default), the function will prompt + the user on whether to remove and replace it, or keep the existing one and + add a new subdirectory with the new timestamp for the current run. + + Args: + exp_name (str): name of the experiment + output_dir (str): output directory of the experiment + save_checkpoints (bool): if save checkpoints + auto_remove_exp_dir (bool): if True, automatically remove the existing experiment + folder if it exists at the same path. + + Returns: + log_dir (str): path to created log directory (sub-folder in experiment directory) + output_dir (str): path to created models directory (sub-folder in experiment directory) + to store model checkpoints + video_dir (str): path to video directory (sub-folder in experiment directory) + to store rollout videos + """ + + # create directory for where to dump model parameters, tensorboard logs, and videos + base_output_dir = output_dir + if not os.path.isabs(base_output_dir): + base_output_dir = os.path.abspath(base_output_dir) + base_output_dir = os.path.join(base_output_dir, exp_name) + if os.path.exists(base_output_dir): + # if not auto_remove_exp_dir: + # ans = input( + # "WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format( + # base_output_dir + # ) + # ) + # else: + # ans = "y" + # if ans == "y": + # print("REMOVING") + # shutil.rmtree(base_output_dir) + if auto_remove_exp_dir: + print(f"REMOVING {base_output_dir}") + shutil.rmtree(base_output_dir) + os.makedirs(base_output_dir, exist_ok=True) + + # version the run + existing_runs = [ + a + for a in os.listdir(base_output_dir) + if os.path.isdir(os.path.join(base_output_dir, a)) + ] + run_counts = [-1] + for ep in existing_runs: + m = ep.split("run") + if len(m) == 2 and m[0] == "": + run_counts.append(int(m[1])) + version_str = "run{}".format(max(run_counts) + 1) + + # only make model directory if model saving is enabled + ckpt_dir = None + if save_checkpoints: + ckpt_dir = os.path.join(base_output_dir, version_str, "checkpoints") + os.makedirs(ckpt_dir) + + # tensorboard directory + log_dir = os.path.join(base_output_dir, version_str, "logs") + os.makedirs(log_dir) + + # video directory + video_dir = os.path.join(base_output_dir, version_str, "videos") + os.makedirs(video_dir) + return base_output_dir, log_dir, ckpt_dir, video_dir, version_str + + +def trajdata_auto_set_batch_size(trainer: "pl.Trainer", model: "pl.LightningModule",datamodule:"pl.LightningDataModule",bs_min = 2,bs_max = None,conservative_reduction=True) -> int: + if bs_max == None: + # power search to find the bs_max + bs_trial = bs_min + + while True: + + datamodule.train_batch_size = bs_trial + datamodule.val_batch_size = bs_trial + try: + trainer.fit(model=model,datamodule=datamodule) + _reset_progress(trainer.fit_loop) + print(f"batch size {bs_trial} succeeded, trying {bs_trial*2}") + bs_min = bs_trial + bs_trial *= 2 + + + except: + print(f"batch size {bs_trial} failed, setting max batch size to {bs_trial}") + bs_max = bs_trial + break + + # the maximum batch size is dataset size divided by validation interval (there needs to be at least 1 validation per epoch) + if bs_trial >=len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100): + bs_max = int(len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100))-1 + break + else: + bs_max = min(bs_max,int(len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100))-1) + # binary search to find the optimal batch size + print(f" starting binary search with minimum batch size {bs_min}, maximum batch size {bs_max}") + while bs_max - bs_min > 1: + bs_trial = (bs_min + bs_max) // 2 + print(f"trying batch size {bs_trial}") + datamodule.train_batch_size = bs_trial + datamodule.val_batch_size = bs_trial + try: + trainer.fit(model=model,datamodule=datamodule) + _reset_progress(trainer.fit_loop) + print(f"batch size {bs_trial} succeeded") + bs_min = bs_trial + except: + bs_max = bs_trial + print(f"batch size {bs_trial} failed") + if bs_max-bs_min torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + + b, a, t, _ = agent_hist_pos.shape + _, _, _, h, w = maps.shape + maps = maps.clone() + agent_hist_pos = TensorUtils.unsqueeze_expand_at(agent_hist_pos,a,1) + agent_mask_tiled = TensorUtils.unsqueeze_expand_at(agent_mask,a,1)*TensorUtils.unsqueeze_expand_at(agent_mask,a,2) + raster_hist_pos = transform_points_tensor(agent_hist_pos.reshape(b*a,-1,2), raster_from_agent.reshape(b*a,3,3)).reshape(b,a,a,t,2) + raster_hist_pos = raster_hist_pos * agent_mask_tiled.unsqueeze(-1) # Set invalid positions to 0.0 Will correct below + + raster_hist_pos[..., 0].clip_(0, (w - 1)) + raster_hist_pos[..., 1].clip_(0, (h - 1)) + raster_hist_pos = torch.round(raster_hist_pos).long() # round pixels [B, A, A, T, 2] + raster_hist_pos = raster_hist_pos.transpose(2,3) + raster_hist_pos_flat = raster_hist_pos[..., 1] * w + raster_hist_pos[..., 0] # [B, A, T, A] + hist_image = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + + ego_mask = torch.zeros_like(raster_hist_pos_flat,dtype=torch.bool) + ego_mask[:,range(a),:,range(a)]=1 + agent_mask = torch.logical_not(ego_mask) + + + hist_image.scatter_(dim=3, index=raster_hist_pos_flat*agent_mask, src=torch.ones_like(hist_image) * -1) # mark other agents with -1 + hist_image.scatter_(dim=3, index=raster_hist_pos_flat*ego_mask, src=torch.ones_like(hist_image)) # mark ego with 1. + hist_image[..., 0] = 0 # correct the 0th index from invalid positions + hist_image[..., -1] = 0 # correct the maximum index caused by out of bound locations + + hist_image = hist_image.reshape(b, a, t, h, w) + + maps = torch.cat((hist_image, maps), dim=2) # treat time as extra channels + return maps + + +def rasterize_agents_sc( + maps: torch.Tensor, + agent_pos: torch.Tensor, + agent_yaw: torch.Tensor, + agent_speed: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + + b, a, t, _ = agent_pos.shape + _, _, _, h, w = maps.shape + + # take the first agent as the center agent + raster_pos = transform_points_tensor(agent_pos.reshape(b,-1,2), raster_from_agent[:,0]).reshape(b,a,t,2) + raster_pos = raster_pos * agent_mask.unsqueeze(-1) # Set invalid positions to 0.0 Will correct below + + raster_pos[..., 0].clip_(0, (w - 1)) + raster_pos[..., 1].clip_(0, (h - 1)) + raster_pos_round = torch.round(raster_pos).long() # round pixels [B, A, T, 2] + raster_dxy = raster_pos - raster_pos_round.float() + + raster_pos_flat = raster_pos_round[..., 1] * w + raster_pos_round[..., 0] # [B, A, T] + prob = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # dx = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # dy = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # heading = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # vel = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + + raster_pos_flat = raster_pos_flat.unsqueeze(-1) + prob.scatter_(dim=3, index=raster_pos_flat, src=torch.ones_like(prob)) + # dx.scatter_(dim=3, index=raster_pos_flat, src=raster_dxy[...,0:1].clone().repeat_interleave(h*w,-1)) + # dy.scatter_(dim=3, index=raster_pos_flat, src=raster_dxy[...,1:2].clone().repeat_interleave(h*w,-1)) + # heading.scatter_(dim=3, index=raster_pos_flat, src=agent_yaw.repeat_interleave(h*w,-1)) + # vel.scatter_(dim=3, index=raster_pos_flat, src=agent_speed.unsqueeze(-1).repeat_interleave(h*w,-1)) + + # feature = torch.stack((prob, dx, dy, heading), dim=3) # [B, A, T, 5, H * W] + feature = prob + feature[..., 0] = 0 # correct the 0th index from invalid positions + feature[..., -1] = 0 # correct the maximum index caused by out of bound locations + + feature = feature.reshape(b, a, t, -1, h, w) + + return feature + + + +def rasterize_agents( + maps: torch.Tensor, + agent_hist_pos: torch.Tensor, + agent_hist_yaw: torch.Tensor, + agent_extent: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, + map_res: torch.Tensor, + cat=True, + filter=None, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + b, a, t, _ = agent_hist_pos.shape + _, _, h, w = maps.shape + + + agent_hist_pos = agent_hist_pos.reshape(b, a * t, 2) + raster_hist_pos = transform_points_tensor(agent_hist_pos, raster_from_agent) + raster_hist_pos[~agent_mask.reshape(b, a * t)] = 0.0 # Set invalid positions to 0.0 Will correct below + raster_hist_pos = raster_hist_pos.reshape(b, a, t, 2).permute(0, 2, 1, 3) # [B, T, A, 2] + raster_hist_pos[..., 0].clip_(0, (w - 1)) + raster_hist_pos[..., 1].clip_(0, (h - 1)) + raster_hist_pos = torch.round(raster_hist_pos).long() # round pixels + + raster_hist_pos_flat = raster_hist_pos[..., 1] * w + raster_hist_pos[..., 0] # [B, T, A] + + hist_image = torch.zeros(b, t, h * w, dtype=maps.dtype, device=maps.device) # [B, T, H * W] + + hist_image.scatter_(dim=2, index=raster_hist_pos_flat[:, :, 1:], src=torch.ones_like(hist_image) * -1) # mark other agents with -1 + hist_image.scatter_(dim=2, index=raster_hist_pos_flat[:, :, [0]], src=torch.ones_like(hist_image)) # mark ego with 1. + hist_image[:, :, 0] = 0 # correct the 0th index from invalid positions + hist_image[:, :, -1] = 0 # correct the maximum index caused by out of bound locations + + hist_image = hist_image.reshape(b, t, h, w) + if filter=="0.5-1-0.5": + kernel = torch.tensor([[0.5, 0.5, 0.5], + [0.5, 1., 0.5], + [0.5, 0.5, 0.5]]).to(hist_image.device) + + kernel = kernel.view(1, 1, 3, 3).repeat(t, t, 1, 1) + hist_image = F.conv2d(hist_image, kernel,padding=1) + if cat: + maps = maps.clone() + maps = torch.cat((hist_image, maps), dim=1) # treat time as extra channels + return maps + else: + return hist_image + +def rasterize_agents_rec( + maps: torch.Tensor, + agent_hist_pos: torch.Tensor, + agent_hist_yaw: torch.Tensor, + agent_extent: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, + map_res: torch.Tensor, + cat=True, + ego_neg = False, + parallel_raster=True, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + with torch.no_grad(): + b, a, t, _ = agent_hist_pos.shape + _, _, h, w = maps.shape + + coord_tensor = torch.cat((torch.arange(w).view(w,1,1).repeat_interleave(h,1), + torch.arange(h).view(1,h,1).repeat_interleave(w,0),),-1).to(maps.device) + + agent_hist_pos = agent_hist_pos.reshape(b, a * t, 2) + raster_hist_pos = transform_points_tensor(agent_hist_pos, raster_from_agent) + + + raster_hist_pos[~agent_mask.reshape(b, a * t)] = 0.0 # Set invalid positions to 0.0 Will correct below + + raster_hist_pos = raster_hist_pos.reshape(b, a, t, 2).permute(0, 2, 1, 3) # [B, T, A, 2] + + raster_hist_pos_yx = torch.cat((raster_hist_pos[...,1:],raster_hist_pos[...,0:1]),-1) + + if parallel_raster: + # vectorized version, uses much more memory + coord_tensor_tiled = coord_tensor.view(1,1,1,h,w,-1).repeat(b,t,a,1,1,1) + dyx = raster_hist_pos_yx[...,None,None,:]-coord_tensor_tiled + cos_yaw = torch.cos(-agent_hist_yaw) + sin_yaw = torch.sin(-agent_hist_yaw) + + rotM = torch.stack( + [ + torch.stack([cos_yaw, sin_yaw], dim=-1), + torch.stack([-sin_yaw, cos_yaw], dim=-1), + ],dim=-2, + ) + rotM = rotM.transpose(1,2) + rel_yx = torch.matmul(rotM.unsqueeze(-3).repeat(1,1,1,h,w,1,1),dyx.unsqueeze(-1)).squeeze(-1) + agent_extent_yx = torch.cat((agent_extent[...,1:2],agent_extent[...,0:1]),-1) + extent_tiled = agent_extent_yx[:,None,:,None,None] + + flag = (torch.abs(rel_yx)1: + # aggregate along the agent dimension + hist_img = hist_img[:,:,0] + hist_img[:,:,1:].max(2)[0]*(hist_img[:,:,0]==0) + else: + hist_img = hist_img.squeeze(2) + else: + + # loop through all agents, slow but memory efficient + coord_tensor_tiled = coord_tensor.view(1,1,h,w,-1).repeat(b,t,1,1,1) + agent_extent_yx = torch.cat((agent_extent[...,1:2],agent_extent[...,0:1]),-1) + hist_img_ego = torch.zeros([b,t,h,w],device=maps.device) + hist_img_nb = torch.zeros([b,t,h,w],device=maps.device) + for i in range(raster_hist_pos_yx.shape[-2]): + dyx = raster_hist_pos_yx[...,i,None,None,:]-coord_tensor_tiled + yaw_i = agent_hist_yaw[:,i] + cos_yaw = torch.cos(-yaw_i) + sin_yaw = torch.sin(-yaw_i) + + rotM = torch.stack( + [ + torch.stack([cos_yaw, sin_yaw], dim=-1), + torch.stack([-sin_yaw, cos_yaw], dim=-1), + ],dim=-2, + ) + + rel_yx = torch.matmul(rotM.unsqueeze(-3).repeat(1,1,h,w,1,1),dyx.unsqueeze(-1)).squeeze(-1) + extent_tiled = agent_extent_yx[:,None,i,None,None] + + flag = (torch.abs(rel_yx)1: + hist_img = hist_img_ego + hist_img_nb*(hist_img_ego==0) + else: + hist_img = hist_img_ego + + if cat: + maps = maps.clone() + maps = torch.cat((hist_img, maps), dim=1) # treat time as extra channels + return maps + else: + return hist_img + + + +def get_drivable_region_map(maps): + if isinstance(maps, torch.Tensor): + if maps.shape[-3]>=7: + drivable = torch.amax(maps[..., -7:-4, :, :], dim=-3).bool() + else: + drivable = torch.amax(maps[..., -3:, :, :], dim=-3).bool() + else: + if maps.shape[-3]>=7: + drivable = np.amax(maps[..., -7:-4, :, :], axis=-3).astype(bool) + else: + drivable = np.amax(maps[..., -3:, :, :], dim=-3).astype(bool) + return drivable + + +def maybe_pad_neighbor(batch): + """Pad neighboring agent's history to the same length as that of the ego using NaNs""" + hist_len = batch["agent_hist"].shape[1] + fut_len = batch["agent_fut"].shape[1] + b, a, neigh_len, _ = batch["neigh_hist"].shape + empty_neighbor = a == 0 + if empty_neighbor: + batch["neigh_hist"] = torch.ones(b, 1, hist_len, batch["neigh_hist"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_fut"] = torch.ones(b, 1, fut_len, batch["neigh_fut"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_types"] = torch.zeros(b, 1,device=batch["agent_hist"].device) + batch["neigh_hist_extents"] = torch.zeros(b, 1, hist_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + batch["neigh_fut_extents"] = torch.zeros(b, 1, fut_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + elif neigh_len < hist_len: + hist_pad = torch.ones(b, a, hist_len - neigh_len, batch["neigh_hist"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_hist"] = torch.cat((hist_pad, batch["neigh_hist"]), dim=2) + hist_pad = torch.zeros(b, a, hist_len - neigh_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + batch["neigh_hist_extents"] = torch.cat((hist_pad, batch["neigh_hist_extents"]), dim=2) + +def parse_scene_centric(batch: dict, rasterize_mode:str): + fut_pos, fut_yaw, fut_speed, fut_mask = trajdata2posyawspeed(batch["agent_fut"]) + hist_pos, hist_yaw, hist_speed, hist_mask = trajdata2posyawspeed(batch["agent_hist"]) + + curr_pos = hist_pos[:,:,-1] + curr_yaw = hist_yaw[:,:,-1] + if batch["centered_agent_state"].shape[-1]==7: + world_yaw = batch["centered_agent_state"][...,6] + else: + assert batch["centered_agent_state"].shape[-1]==8 + world_yaw = torch.atan2(batch["centered_agent_state"][...,6],batch["centered_agent_state"][...,7]) + curr_speed = hist_speed[..., -1] + centered_state = batch["centered_agent_state"] + centered_yaw = centered_state[:, -1] + centered_pos = centered_state[:, :2] + old_type = batch["agent_type"] + agent_type = torch.zeros_like(old_type) + agent_type[old_type < 0] = 0 + agent_type[old_type ==[3,4]] = 2 + agent_type[old_type ==1] = 3 + agent_type[old_type ==2] = 2 + agent_type[old_type ==5] = 4 + + # mask out invalid extents + agent_hist_extent = batch["agent_hist_extent"] + agent_hist_extent[torch.isnan(agent_hist_extent)] = 0. + + if not batch["centered_agent_from_world_tf"].isnan().any(): + centered_world_from_agent = torch.inverse(batch["centered_agent_from_world_tf"]) + else: + centered_world_from_agent = None + b,a = curr_yaw.shape[:2] + agents_from_center = (GeoUtils.transform_matrices(-curr_yaw.flatten(),torch.zeros(b*a,2,device=curr_yaw.device)) + @GeoUtils.transform_matrices(torch.zeros(b*a,device=curr_yaw.device),-curr_pos.reshape(-1,2))).reshape(*curr_yaw.shape[:2],3,3) + center_from_agents = GeoUtils.transform_matrices(curr_yaw.flatten(),curr_pos.reshape(-1,2)).reshape(*curr_yaw.shape[:2],3,3) + # map-related + if batch["maps"] is not None: + map_res = batch["maps_resolution"][0,0] + h, w = batch["maps"].shape[-2:] + # TODO: pass env configs to here + + centered_raster_from_agent = torch.Tensor([ + [map_res, 0, 0.125 * w], + [0, map_res, 0.5 * h], + [0, 0, 1] + ]).type_as(agents_from_center) + + centered_agent_from_raster,_ = torch.linalg.inv_ex(centered_raster_from_agent) + + raster_from_center = centered_raster_from_agent @ agents_from_center + center_from_raster = center_from_agents @ centered_agent_from_raster + + raster_from_world = batch["rasters_from_world_tf"] + world_from_raster,_ = torch.linalg.inv_ex(raster_from_world) + raster_from_world[torch.isnan(raster_from_world)] = 0. + world_from_raster[torch.isnan(world_from_raster)] = 0. + + if rasterize_mode=="none": + maps = batch["maps"] + elif rasterize_mode=="point": + maps = rasterize_agents_scene( + batch["maps"], + hist_pos, + hist_yaw, + None, + hist_mask, + raster_from_center, + map_res + ) + elif rasterize_mode=="square": + #TODO: add the square rasterization function for scene-centric data + raise NotImplementedError + elif rasterize_mode=="point_sc": + hist_hm = rasterize_agents_sc( + batch["maps"], + hist_pos, + hist_yaw, + hist_speed, + hist_mask, + raster_from_center, + ) + fut_hm = rasterize_agents_sc( + batch["maps"], + fut_pos, + fut_yaw, + fut_speed, + fut_mask, + raster_from_center, + ) + maps = batch["maps"] + drivable_map = get_drivable_region_map(batch["maps"]) + else: + maps = None + drivable_map = None + raster_from_center = None + center_from_raster = None + raster_from_world = None + centered_agent_from_raster = None + centered_raster_from_agent = None + + extent_scale = 1.0 + + + d = dict( + image=maps, + drivable_map=drivable_map, + fut_pos=fut_pos, + fut_yaw=fut_yaw, + fut_mask=fut_mask, + hist_pos=hist_pos, + hist_yaw=hist_yaw, + hist_mask=hist_mask, + curr_speed=curr_speed, + centroid=curr_pos, + world_yaw=world_yaw, + type=agent_type, + extent=agent_hist_extent.max(dim=-2)[0] * extent_scale, + raster_from_agent=centered_raster_from_agent, + agent_from_raster=centered_agent_from_raster, + raster_from_center=raster_from_center, + center_from_raster=center_from_raster, + agents_from_center = agents_from_center, + center_from_agents = center_from_agents, + raster_from_world=raster_from_world, + agent_from_world=batch["centered_agent_from_world_tf"], + world_from_agent=centered_world_from_agent, + ) + if rasterize_mode=="point_sc": + d["hist_hm"] = hist_hm + d["fut_hm"] = fut_hm + + return d + +def parse_node_centric(batch: dict,rasterize_mode:str,): + maybe_pad_neighbor(batch) + fut_pos, fut_yaw, _, fut_mask = trajdata2posyawspeed(batch["agent_fut"]) + hist_pos, hist_yaw, hist_speed, hist_mask = trajdata2posyawspeed(batch["agent_hist"]) + curr_speed = hist_speed[..., -1] + curr_state = batch["curr_agent_state"] + curr_yaw = curr_state[:, -1] + curr_pos = curr_state[:, :2] + + # convert nuscenes types to l5kit types + # old_type = batch["agent_type"] + # agent_type = torch.zeros_like(old_type) + # agent_type[old_type < 0] = 0 + # agent_type[old_type ==[3,4]] = 2 + # agent_type[old_type ==1] = 3 + # agent_type[old_type ==2] = 2 + # agent_type[old_type ==5] = 4 + # mask out invalid extents + agent_hist_extent = batch["agent_hist_extent"] + agent_hist_extent[torch.isnan(agent_hist_extent)] = 0. + + neigh_hist_pos, neigh_hist_yaw, neigh_hist_speed, neigh_hist_mask = trajdata2posyawspeed(batch["neigh_hist"]) + neigh_fut_pos, neigh_fut_yaw, _, neigh_fut_mask = trajdata2posyawspeed(batch["neigh_fut"]) + if neigh_hist_speed.nelement() > 0: + neigh_curr_speed = neigh_hist_speed[..., -1] + else: + neigh_curr_speed = neigh_hist_speed.unsqueeze(-1) + # old_neigh_types = batch["neigh_types"] + # # convert nuscenes types to l5kit types + # neigh_types = torch.zeros_like(old_neigh_types) + # # neigh_types = torch.zeros_like(old_type) + # neigh_types[old_neigh_types < 0] = 0 + # neigh_types[old_neigh_types ==[3,4]] = 2 + # neigh_types[old_neigh_types ==1] = 3 + # neigh_types[old_neigh_types ==2] = 2 + # neigh_types[old_neigh_types ==5] = 4 + + # mask out invalid extents + neigh_hist_extents = batch["neigh_hist_extents"] + neigh_hist_extents[torch.isnan(neigh_hist_extents)] = 0. + + world_from_agents = torch.inverse(batch["agents_from_world_tf"]) + if batch["curr_agent_state"].shape[-1]==7: + world_yaw = batch["curr_agent_state"][...,6] + else: + assert batch["curr_agent_state"].shape[-1]==8 + world_yaw = torch.atan2(batch["curr_agent_state"][...,6],batch["curr_agent_state"][...,7]) + + # map-related + if batch["maps"] is not None: + map_res = batch["maps_resolution"][0] + h, w = batch["maps"].shape[-2:] + # TODO: pass env configs to here + raster_from_agent = torch.tensor([ + [map_res, 0, 0.125 * w], + [0, map_res, 0.5 * h], + [0, 0, 1] + ],device=curr_pos.device) + agent_from_raster = torch.inverse(raster_from_agent) + raster_from_agent = TensorUtils.unsqueeze_expand_at(raster_from_agent, size=batch["maps"].shape[0], dim=0) + agent_from_raster = TensorUtils.unsqueeze_expand_at(agent_from_raster, size=batch["maps"].shape[0], dim=0) + raster_from_world = torch.bmm(raster_from_agent, batch["agents_from_world_tf"]) + if neigh_hist_pos.nelement()>0: + all_hist_pos = torch.cat((hist_pos[:, None], neigh_hist_pos), dim=1) + all_hist_yaw = torch.cat((hist_yaw[:, None], neigh_hist_yaw), dim=1) + + all_extents = torch.cat((batch["agent_hist_extent"].unsqueeze(1),batch["neigh_hist_extents"]),1).max(dim=2)[0][...,:2] + all_hist_mask = torch.cat((hist_mask[:, None], neigh_hist_mask), dim=1) + else: + all_hist_pos = hist_pos[:, None] + all_hist_yaw = hist_yaw[:, None] + all_extents = batch["agent_hist_extent"].unsqueeze(1)[...,:2] + all_hist_mask = hist_mask[:, None] + if rasterize_mode=="none": + maps = batch["maps"] + elif rasterize_mode=="point": + maps = rasterize_agents( + batch["maps"], + all_hist_pos, + all_hist_yaw, + all_extents, + all_hist_mask, + raster_from_agent, + map_res + ) + elif rasterize_mode=="square": + maps = rasterize_agents_rec( + batch["maps"], + all_hist_pos, + all_hist_yaw, + all_extents, + all_hist_mask, + raster_from_agent, + map_res + ) + else: + raise Exception("unknown rasterization mode") + drivable_map = get_drivable_region_map(batch["maps"]) + else: + maps = None + drivable_map = None + raster_from_agent = None + agent_from_raster = None + raster_from_world = None + + extent_scale = 1.0 + d = dict( + image=maps, + drivable_map=drivable_map, + fut_pos=fut_pos, + fut_yaw=fut_yaw, + fut_mask=fut_mask, + hist_pos=hist_pos, + hist_yaw=hist_yaw, + hist_mask=hist_mask, + curr_speed=curr_speed, + centroid=curr_pos, + world_yaw=world_yaw, + type=batch["agent_type"], + extent=agent_hist_extent.max(dim=-2)[0] * extent_scale, + raster_from_agent=raster_from_agent, + agent_from_raster=agent_from_raster, + raster_from_world=raster_from_world, + agent_from_world=batch["agents_from_world_tf"], + world_from_agent=world_from_agents, + neigh_hist_pos=neigh_hist_pos, + neigh_hist_yaw=neigh_hist_yaw, + neigh_hist_mask=neigh_hist_mask, # dump hack to agree with l5kit's typo ... + neigh_curr_speed=neigh_curr_speed, + neigh_fut_pos=neigh_fut_pos, + neigh_fut_yaw=neigh_fut_yaw, + neigh_fut_mask=neigh_fut_mask, + neigh_types=batch["neigh_types"], + neigh_extents=neigh_hist_extents.max(dim=-2)[0] * extent_scale if neigh_hist_extents.nelement()>0 else None, + + ) + # if "agent_lanes" in batch: + # d["ego_lanes"] = batch["agent_lanes"] + + return d + +@torch.no_grad() +def parse_trajdata_batch(batch, rasterize_mode="point"): + if isinstance(batch,AgentBatch): + # Be careful here, without making a copy of vars(batch) we would modify the fields of AgentBatch. + batch = dict(vars(batch)) + d = parse_node_centric(batch,rasterize_mode) + elif isinstance(batch,SceneBatch): + batch = dict(vars(batch)) + d = parse_scene_centric(batch,rasterize_mode) + elif isinstance(batch,dict): + batch = dict(batch) + if "num_agents" in batch: + # scene centric + d = parse_scene_centric(batch,rasterize_mode) + + else: + # agent centric + d = parse_node_centric(batch,rasterize_mode) + + batch.update(d) + for k,v in batch.items(): + if isinstance(v,torch.Tensor): + batch[k]=v.nan_to_num(0) + batch.pop("agent_name", None) + batch.pop("robot_fut", None) + return batch + + +def get_modality_shapes(cfg: ExperimentConfig, rasterize_mode: str = "point"): + h = cfg.env.rasterizer.raster_size + if rasterize_mode=="none": + return dict(static=(3,h,h),dynamic=(0,h,h),image=(3,h,h)) + else: + num_channels = (cfg.history_num_frames + 1) + 3 + return dict(static=(3,h,h),dynamic=(cfg.history_num_frames + 1,h,h),image=(num_channels, h, h)) + + +def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B,N,T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + veh_mask = raw_types == int(AgentType["VEHICLE"]) + ped_mask = raw_types == int(AgentType["PEDESTRIAN"]) + + edges = torch.zeros([B,N,A,T,10],device=ego_trajectories.device) + edges[...,:3] = ego_trajectories.unsqueeze(2).repeat(1,1,A,1,1) + if agent_trajectories.ndim==4: + edges[...,3:6] = agent_trajectories.unsqueeze(1).repeat(1,N,1,1,1) + else: + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,1,2).repeat(1,N,A,T,1) + edges[...,8:] = agent_extents.reshape(B,1,A,1,2).repeat(1,N,1,T,1) + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + + + print("abc") + +def gen_EC_edges(ego_trajectories,agent_trajectories,ego_extents, agent_extents, raw_types,mask=None): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,A,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + mask (optional, torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + + B,A = ego_trajectories.shape[:2] + T = ego_trajectories.shape[-2] + + veh_mask = raw_types == int(AgentType["VEHICLE"]) + ped_mask = raw_types == int(AgentType["PEDESTRIAN"]) + + + if ego_trajectories.ndim==4: + edges = torch.zeros([B,A,T,10],device=ego_trajectories.device) + edges[...,:3] = ego_trajectories + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,1,T,1) + elif ego_trajectories.ndim==5: + + K = ego_trajectories.shape[2] + edges = torch.zeros([B,A*K,T,10],device=ego_trajectories.device) + edges[...,:3] = TensorUtils.join_dimensions(ego_trajectories,1,3) + edges[...,3:6] = agent_trajectories.repeat(1,K,1,1) + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A*K,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,K,T,1) + veh_mask = veh_mask.tile(1,K) + ped_mask = ped_mask.tile(1,K) + if mask is not None: + veh_mask = veh_mask*mask + ped_mask = ped_mask*mask + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def generate_edges( + raw_type, + extents, + pos_pred, + yaw_pred, + batch_first = False, +): + veh_mask = raw_type == int(AgentType["VEHICLE"]) + ped_mask = raw_type == int(AgentType["PEDESTRIAN"]) + + agent_mask = veh_mask | ped_mask + edge_types = ["VV", "VP", "PV", "PP"] + edges = {et: list() for et in edge_types} + for i in range(agent_mask.shape[0]): + agent_idx = torch.where(agent_mask[i] != 0)[0] + edge_idx = torch.combinations(agent_idx, r=2) + VV_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + VP_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + PV_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + PP_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + if pos_pred.ndim == 4: + edges_of_all_types = torch.cat( + ( + pos_pred[i, edge_idx[:, 0], :], + yaw_pred[i, edge_idx[:, 0], :], + pos_pred[i, edge_idx[:, 1], :], + yaw_pred[i, edge_idx[:, 1], :], + extents[i, edge_idx[:, 0]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + extents[i, edge_idx[:, 1]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[VV_idx]) + edges["VP"].append(edges_of_all_types[VP_idx]) + edges["PV"].append(edges_of_all_types[PV_idx]) + edges["PP"].append(edges_of_all_types[PP_idx]) + elif pos_pred.ndim == 5: + + edges_of_all_types = torch.cat( + ( + pos_pred[i, :, edge_idx[:, 0], :], + yaw_pred[i, :, edge_idx[:, 0], :], + pos_pred[i, :, edge_idx[:, 1], :], + yaw_pred[i, :, edge_idx[:, 1], :], + extents[i, None, edge_idx[:, 0], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + extents[i, None, edge_idx[:, 1], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[:, VV_idx]) + edges["VP"].append(edges_of_all_types[:, VP_idx]) + edges["PV"].append(edges_of_all_types[:, PV_idx]) + edges["PP"].append(edges_of_all_types[:, PP_idx]) + if batch_first: + for et, v in edges.items(): + edges[et] = pad_sequence(v, batch_first=True,padding_value=torch.nan) + else: + if pos_pred.ndim == 4: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=0) + elif pos_pred.ndim == 5: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=1) + return edges + + +def merge_scene_batches(scene_batches: List[SceneBatch], dt: float) -> SceneBatch: + assert scene_batches[0].history_pad_dir == PadDirection.BEFORE + assert all([b.agent_hist.shape[0] == 1 for b in scene_batches]), "only batch_size=1 is supported" + + # Convert everything to world coordinates + scene_batches = [b.apply_transform(b.centered_world_from_agent_tf) for b in scene_batches] + state_format = scene_batches[0].agent_hist._format + + # Not all agent might be present at all time steps, so we match them by name. + # Get unique names, use np.unique return_index and sort to preserve ordering. + agent_names = [np.array(b.agent_names[0]) for b in scene_batches] + all_agent_names = np.concatenate(agent_names) + _, idx = np.unique(all_agent_names, return_index=True) + unique_agent_names = all_agent_names[np.sort(idx)] + + num_agents = len(unique_agent_names) + hist_len = len(scene_batches) + fut_len = scene_batches[-1].agent_fut.shape[-2] + + # Create full history with nans, then replace them for each time step + agent_hist = torch.full((1, num_agents, hist_len, scene_batches[0].agent_hist.shape[-1]), dtype=scene_batches[0].agent_hist.dtype, fill_value=torch.nan) + agent_hist_extent = torch.full((1, num_agents, hist_len, 2), dtype=scene_batches[0].agent_hist_extent.dtype, fill_value=torch.nan) + agent_type = torch.full((1, num_agents), dtype=scene_batches[0].agent_type.dtype, fill_value=-1) + + for t, scene_batch in enumerate(scene_batches): + match_inds = np.argwhere(np.array(scene_batch.agent_names[0])[:, None] == unique_agent_names[None, :]) # n_current, n_all -> n_current, 2 + assert match_inds.shape[0] == len(scene_batch.agent_names[0]), "there should be only 1 unique match" + agent_hist[0, match_inds[:, 1], t, :] = scene_batch.agent_hist[0, match_inds[:, 0], -1, :] + agent_hist_extent[0, match_inds[:, 1], t, :] = scene_batch.agent_hist_extent[0, match_inds[:, 0], -1, :] + agent_type[0, match_inds[:, 1]] = scene_batch.agent_type[0, match_inds[:, 0]] + + # Dummy future, repeat last state + agent_fut = agent_hist[:, :, -1:, :].repeat_interleave(fut_len, dim=-2) + agent_fut_extent = agent_hist_extent[:, :, -1:, :].repeat_interleave(fut_len, dim=-2) + + # Create trajdata batch + merged_batch = SceneBatch( + data_idx=torch.tensor([torch.nan]), + scene_ts=scene_batches[0].scene_ts, + scene_ids=scene_batches[0].scene_ids, + dt=torch.tensor([dt]), + num_agents=torch.tensor([num_agents]), + agent_type=agent_type, + centered_agent_state=scene_batches[0].centered_agent_state, + agent_names=[list(unique_agent_names)], + agent_hist=StateTensor.from_array(agent_hist, state_format), + agent_hist_extent=agent_hist_extent, + agent_hist_len=torch.tensor([[hist_len] * num_agents]), # len includes current state + agent_fut=StateTensor.from_array(agent_fut, state_format), + agent_fut_extent=agent_fut_extent, + agent_fut_len=torch.from_numpy(np.array([[fut_len] * num_agents])), + robot_fut=None, + robot_fut_len=None, + map_names=scene_batches[0].map_names, + maps=scene_batches[0].map_names, + maps_resolution=scene_batches[0].maps_resolution, + vector_maps=scene_batches[0].vector_maps, + rasters_from_world_tf=scene_batches[0].rasters_from_world_tf, + centered_agent_from_world_tf=scene_batches[0].centered_agent_from_world_tf, + centered_world_from_agent_tf=scene_batches[0].centered_world_from_agent_tf, + history_pad_dir=PadDirection.BEFORE, + extras=dict(scene_batches[0].extras), + ) + + return merged_batch diff --git a/diffstack/utils/tree.py b/diffstack/utils/tree.py new file mode 100644 index 0000000..845fdde --- /dev/null +++ b/diffstack/utils/tree.py @@ -0,0 +1,102 @@ +import itertools +from collections import defaultdict +import networkx as nx + +class Tree(object): + + def __init__(self, content, parent, depth): + self.content = content + self.children = list() + self.parent = parent + if parent is not None: + parent.expand(self) + self.depth = depth + self.attribute = dict() + + def expand(self, child): + self.children.append(child) + + def expand_set(self, children): + self.children += children + + def isroot(self): + return self.parent is None + + def isleaf(self): + return len(self.children) == 0 + + def get_subseq_trajs(self): + return [child.traj for child in self.children] + + + def get_all_leaves(self,leaf_set=[]): + if self.isleaf(): + leaf_set.append(self) + else: + for child in self.children: + leaf_set = child.get_all_leaves(leaf_set) + return leaf_set + def get_label(self): + raise NotImplementedError + + @staticmethod + def get_nodes_by_level(obj,depth,nodes=None,trim_short_branch=True): + assert obj.depth<=depth + if nodes is None: + nodes = defaultdict(lambda: list()) + if obj.depth==depth: + nodes[depth].append(obj) + return nodes, True + else: + if obj.isleaf(): + return nodes, False + + else: + flag = False + children_flags = dict() + for child in obj.children: + nodes, child_flag = Tree.get_nodes_by_level(child,depth,nodes) + children_flags[child] = child_flag + flag = flag or child_flag + if trim_short_branch: + obj.children = [child for child in obj.children if children_flags[child]] + if flag: + nodes[obj.depth].append(obj) + return nodes, flag + + @staticmethod + def get_children(obj): + if isinstance(obj, Tree): + return obj.children + elif isinstance(obj, list): + children = [node.children for node in obj] + children = list(itertools.chain.from_iterable(children)) + return children + else: + raise TypeError("obj must be a TrajTree or a list") + + def as_network(self): + G = nx.Graph() + G.add_node(self.get_label()) + for child in self.children: + G = nx.union(G,child.as_network()) + G.add_edge(self.get_label(),child.get_label()) + return G + + def plot(self): + G = self.as_network() + + pos = nx.nx_agraph.pygraphviz_layout(G, prog="dot") + nx.draw(G, pos,with_labels = True) + # nx.draw(G, with_labels = True) + + + + +def depth_first_traverse(tree:Tree,func, visited:dict, result): + result = func(tree,result) + visited[tree] = True + for child in tree.children: + if not (child in visited and visited[child]): + result, visited = depth_first_traverse(child, func, visited, result) + return result, visited \ No newline at end of file diff --git a/diffstack/utils/utils.py b/diffstack/utils/utils.py new file mode 100644 index 0000000..7e4a524 --- /dev/null +++ b/diffstack/utils/utils.py @@ -0,0 +1,846 @@ +import os +import random +import torch +import time +import numpy as np +import pickle +import dill +import collections.abc + +from datetime import timedelta +from collections import defaultdict, OrderedDict +from scipy.interpolate import interp1d +from typing import Dict, Union, Tuple, Any, Optional, Iterable, List +from torch.utils.data._utils.collate import default_collate + +from nuscenes.map_expansion import arcline_path_utils +from nuscenes.map_expansion.map_api import NuScenesMap + +from diffstack.utils.geometry_utils import batch_rotate_2D +# Expose a couple of util functions defined in different submodules. +from trajdata.utils.arr_utils import angle_wrap, batch_select +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +container_abcs = collections.abc + + +# Distributed +import torch.distributed +try: + import gpu_affinity + USE_GPU_AFFINITY = True +except: + USE_GPU_AFFINITY = False + + +def initialize_torch_distributed(local_rank: int): + if torch.cuda.is_available(): + backend = 'nccl' + # Set gpu affinity so that the optimal memory segment is used for multi-gpu training + # https://gitlab-master.nvidia.com/dl/gwe/gpu_affinity + if USE_GPU_AFFINITY: + gpu_affinity.set_affinity(local_rank, int(os.environ["WORLD_SIZE"])) + else: + backend = 'gloo' + + torch.distributed.init_process_group(backend=backend, + init_method='env://', + # default timeout torch.distributed.default_pg_timeout=1800 (sec, =30mins) + # increase timeout for datacaching where workload for different gpus can be very different + timeout=timedelta(hours=10)) # 10h + + +def set_all_seeds(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def prepeare_torch_env(rank, hyperparams): + if torch.cuda.is_available() and hyperparams.run["device"] != 'cpu': + hyperparams.run["device"] = f'cuda:{rank}' + torch.cuda.set_device(rank) + else: + hyperparams.run["device"] = f'cpu' + + if hyperparams.run["seed"] is not None: + set_all_seeds(hyperparams.run["seed"]) + +class CudaTimer(object): + def __init__(self, enabled=True): + self.enabled=enabled + self.timers=defaultdict(list) + + @staticmethod + def cuda_sync_maybe(): + try: + torch.cuda.synchronize() + except: + pass + + def start(self, name): + if self.enabled: + self.cuda_sync_maybe() + self.timers[name].append(time.time()) + + def end(self, name): + if self.enabled: + assert name in self.timers + self.cuda_sync_maybe() + self.timers[name].append(time.time()) + + def print(self, names=None): + if self.enabled: + if names is None: + names = self.timers.keys() + s = "Timer " + " ".join([f"{name}={self.timers[name][-1]-self.timers[name][0]:.3f}" for name in names]) + print (s) + self.timers.clear() + + +def merge_dicts_with_prefix(prefix_dict: Dict[str, Dict[str, Any]], separator: str = ".") -> Dict[str, Any]: + output = {} + for prefix, dict_instance in prefix_dict.items(): + for k, v in dict_instance.items(): + output[f"{prefix}{separator}{k}"] = v + return output + + +def batch_derivative_of(states, dt = 1.): + """ + states: [..., T, state_dim] + dt: time difference between states in input trajectory + """ + diff = states[..., 1:, :] - states[..., :-1, :] + # Add first state derivative + if isinstance(states, torch.Tensor): + diff = torch.cat((diff[..., :1, :], diff), dim=-2) + else: + diff = np.concatenate((diff[..., :1, :], diff), axis=-2) + return diff / dt + + +def subsample_future(x: torch.Tensor, new_horizon: int, current_horizon: int): + return subsample_traj(x, new_horizon, current_horizon, is_future=True) + + +def subsample_history(x: torch.Tensor, new_horizon: int, current_horizon: int): + return subsample_traj(x, new_horizon, current_horizon, is_future=False) + + +def subsample_traj(x: torch.Tensor, new_horizon: int, current_horizon: int, is_future: bool = True): + """x: [..., planh+1, :]""" + assert x.shape[-2] == current_horizon + 1 + if current_horizon != new_horizon: + if new_horizon == 0: + subsample_inds = [0] + else: + assert current_horizon % new_horizon == 0, f"planning horizon ({new_horizon}) needs to be a multiple of prediction horizon ({current_horizon})" + subsample_gap = current_horizon // new_horizon + subsample_inds = list(range(0, current_horizon+1, subsample_gap)) # for gap=2 [0,1,2,3,4,5] --> [0,2,4] + if not is_future: + # for gap=2 [0,1,2,3,4,5] --> [0,2,4] --> [-1, -3, -5] --> [-5, -3, -1] (equivalent --> [1, 3, 5]) + subsample_inds = [-ind-1 for ind in subsample_inds] + subsample_inds.reverse() + + assert len(subsample_inds) == new_horizon+1 + if x.ndim == 2: + return x[subsample_inds] + else: + return x.transpose(-2, 0)[subsample_inds].transpose(-2, 0) + else: + return x + + +def normalize_angle(h): + return (h + np.pi) % (2.0 * np.pi) - np.pi + + +def traj_xyhvv_to_pred(traj, dt): + # Input: [x, y, h, vx, vy] + # Output prediction state: ['x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta'] + x, y, h, vx, vy = np.split(traj, 5, axis=-1) + ax = batch_derivative_of(vx, dt=dt) + ay = batch_derivative_of(vy, dt=dt) + pred_state = np.concatenate(( + x, y, vx, vy, ax, ay, np.sin(h), np.cos(h) + ), axis=-1) + return pred_state + + +def traj_pred_to_xyhvv(traj): + # Input prediction state: ['x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta'] + # Output: [x, y, h, vx, vy] + x, y, vx, vy, ax, ay, sinh, cosh = np.split(traj, 8, axis=-1) + h = np.arctan2(sinh, cosh) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xy_to_xyh(traj: Union[np.ndarray, torch.Tensor]): + xy = traj + dxy = xy[..., 1:, :2] - xy[..., :-1, :2] + if isinstance(traj, torch.Tensor): + h = torch.atan2(dxy[..., 1], dxy[..., 0])[..., None] # TODO invalid for near-zero velocity + h = torch.concat((h, h[..., -1:, :]), dim=-2) # extend time + return torch.concat((xy, h), dim=-1) + else: + h = np.arctan2(dxy[..., 1], dxy[..., 0])[..., None] # TODO invalid for near-zero velocity + h = np.concatenate((h, h[..., -1:, :]), axis=-2) # extend time + return np.concatenate((xy, h), axis=-1) + + +def traj_xyh_to_xyhv(traj: Union[np.ndarray, torch.Tensor], dt: float): + xyhvv = traj_xyh_to_xyhvv(traj, dt) + return traj_xyhvv_to_xyhv(xyhvv) + + +def traj_xyh_to_xyhvv(traj: Union[np.ndarray, torch.Tensor], dt: float): + if isinstance(traj, torch.Tensor): + x, y, h = torch.split(traj, 1, dim=-1) + vx = batch_derivative_of(x, dt) + vy = batch_derivative_of(y, dt) + return torch.concat((x, y, h, vx, vy), dim=-1) + else: + x, y, h = np.split(traj, 3, axis=-1) + vx = batch_derivative_of(x, dt) + vy = batch_derivative_of(y, dt) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xyhv_to_xyhvv(traj): + x, y, h, v = np.split(traj, 4, axis=-1) + vx = v * np.cos(h) + vy = v * np.sin(h) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xyhvv_to_xyhv(traj: Union[np.ndarray, torch.Tensor]): + # Use only the forward velocity component and ignore sideway velocity. + if isinstance(traj, torch.Tensor): + x, y, h, vx, vy = torch.split(traj, 1, dim=-1) + v_xy = batch_rotate_2D(torch.stack((vx, vy),-1), -h) # forward and sideway velocity + v = v_xy[...,0] + return torch.concat((x, y, h, v), dim=-1) + else: + x, y, h, vx, vy = np.split(traj, 5, axis=-1) + v_xy = batch_rotate_2D(np.stack((vx, vy),-1), -h) # forward and sideway velocity + v = v_xy[...,0] + return np.concatenate((x, y, h, v), axis=-1) + + +def closest_lane_state(global_state: np.ndarray, nusc_map: NuScenesMap): + nearest_lane = nusc_map.get_closest_lane(x=global_state[0], y=global_state[1]) + lane_rec = nusc_map.get_arcline_path(nearest_lane) + closest_state, _ = arcline_path_utils.project_pose_to_lane(global_state, lane_rec) + return closest_state + +def lane_frenet_features_simple(ego_state: np.ndarray, lane_states: np.ndarray, plot=False): + """Taking the equation from the "Line defined by two points" section of + https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line + as well as this answer: https://stackoverflow.com/a/6853926 + """ + if ego_state.ndim == 2: + # Recursive callse, iterate over batch/time dimension + assert ego_state.shape[0] == lane_states.shape[0] + out = [lane_frenet_features_simple(ego_state[i], lane_states[i], plot=plot) for i in range(ego_state.shape[0])] + return np.stack(out, axis=0) + + assert lane_states.ndim == 2 and ego_state.ndim == 1 + + # Simplified + i = np.argmin(np.square(lane_states[:, :2] - ego_state[None, :2]).sum(-1)) + lane_pt = lane_states[i] + lane_pt_xy = lane_pt[:2] + lane_pt_h = lane_pt[2] + v = np.array([np.cos(lane_pt_h), np.sin(lane_pt_h)]) + proj_len = ((ego_state[:2] - lane_pt_xy) * v).sum(-1) # equivalent of dot(xy-vect_xy, v) + proj_onto_lane = v * proj_len + lane_pt_xy + + # Debug + if plot: + import matplotlib.pyplot as plt + fig = plt.figure() + pts = np.stack([lane_pt_xy - 20*v, lane_pt_xy + 1*v]) + plt.plot(pts[:, 0], pts[:, 1], label="lane line") + plt.scatter(ego_state[0], ego_state[1], label='ego') + plt.scatter(proj_onto_lane[0], proj_onto_lane[1], label='proj') + plt.gca().set_aspect('equal') + plt.legend() + plt.show() + + return np.array([proj_onto_lane[0], proj_onto_lane[1], lane_pt_h]) + + +def closest_lane(ego_xy: torch.Tensor, lane_points: torch.Tensor): + ind = torch.argmin(torch.square(lane_points[..., :2] - ego_xy[..., :2].unsqueeze(-2)).sum(-1), axis=-1) + # Workaround for lack of gather_nd and no broadcasting of gather + # TODO look for faster implementation of gather_nd + lane_comps = torch.unbind(lane_points, axis=-1) + lane_comps = [torch.gather(x, x.ndim-1, ind.unsqueeze(-1)) for x in lane_comps] + return torch.cat(lane_comps, axis=-1) + + +def closest_lane_np(ego_xy: np.ndarray, lane_points_list: Iterable[np.ndarray]): + state_to_lane_dist2 = [ + np.square(lane_points[..., :2] - ego_xy[np.newaxis,..., :2]).sum(-1).min(-1) + for lane_points in lane_points_list] + return np.argmin(state_to_lane_dist2) if len(state_to_lane_dist2)>0 else None + + +def get_pointgoal_from_onroute_lanes(ego_state_xyhv: np.ndarray, lanes_xyh: List[np.ndarray], dt: float, future_len: int, max_vel: float = 29.0, target_acc: float = 1.5) -> np.ndarray: + # TODO we should move this logic to the planner + + assert ego_state_xyhv.ndim == 1 and ego_state_xyhv.shape[-1] == 4 + ego_x, ego_y, ego_h, ego_v = ego_state_xyhv + + if len(lanes_xyh) == 0: + # No lanes. Set goal as current location. + goal_point_xyh = np.stack((ego_x, ego_y, ego_h), axis=-1) + ref_traj_xyh = np.repeat(goal_point_xyh[None], future_len+1, axis=0) + print ("WARNING: no lane for inferring goal") + return ref_traj_xyh, goal_point_xyh + + # Target final velocity, based on accelearting but remaining under max_vel limit + target_v = np.minimum(ego_v + target_acc * future_len * dt, max_vel) + avg_v = 0.5 * (target_v + ego_v) + target_pathlen = avg_v * future_len * dt + + closest_lane_ind = closest_lane_np(ego_state_xyhv[:3], lanes_xyh) + closest_lane = lanes_xyh[closest_lane_ind] + + state_to_lane_dist2 = np.square(closest_lane[:, :2] - ego_state_xyhv[None, :2]).sum(-1) + closest_point_ind = np.argmin(state_to_lane_dist2) + + # Assume polyline is ordered by the lane direction. + if closest_point_ind < closest_lane.shape[0]-1: + future_lane = closest_lane[closest_point_ind:] + # Get distance along lane. + # TODO compute distance from current state, not first lane point + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + else: + # TODO again we should compute distance from current state, negative for first state, positive for second + future_lane = closest_lane[closest_point_ind-1:] + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + + dist_along_lane = np.cumsum(step_len) + + # Find connecting lane + del closest_lane + while dist_along_lane[-1] < target_pathlen: + # TODO manually find continuation of lane, i.e. the lane with first point closest to last point of our lane. + first_lane_points = np.array([lane[0] for lane in lanes_xyh]) + d = np.linalg.norm(first_lane_points[:, :2] - future_lane[-1, None, :2], axis=-1) + closest_lane_ind = np.argmin(d) + + if d[closest_lane_ind] > 3: # 3m maximum to treat them as connected lanes + # No more lanes, we need to stop at the end of current lane. + break + + future_lane = np.concatenate((future_lane, lanes_xyh[closest_lane_ind]), axis=0) + # TODO compute distance from current state, not first lane point + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + dist_along_lane = np.cumsum(step_len) + + del closest_lane_ind + + # Find lane point closest to our target distance. + # TODO this could be done more efficiently, use lane util functions + target_lane_point_ind = np.argmin(np.abs(dist_along_lane - target_pathlen)) + goal_point_xyh = future_lane[target_lane_point_ind] + + # Recomput target pathlen based on goal point. + target_pathlen = dist_along_lane[target_lane_point_ind] + avg_v = target_pathlen / (future_len * dt) + target_v = 2 * avg_v - ego_v + a = (target_v - ego_v) / (future_len * dt) + + # For each future time t, what is the intended length of traversed path. + delta_t = np.arange(future_len+1) * dt + delta_pathlen_t = (delta_t * a * 0.5 + ego_v) * delta_t + + # Interpolate lane at these delta lenth points. + # TODO we need to unwrap angles, do interpolation, and then wrap them again. + interp_fn = interp1d(dist_along_lane, future_lane, bounds_error=False, assume_sorted=True, axis=0) # nan for extrapolation + # interp_fn = interp1d(dist_along_lane, future_lane, fill_value="extrapolate", assume_sorted=True, axis=0) + + ref_traj_xyh = interp_fn(delta_pathlen_t) + + return ref_traj_xyh, goal_point_xyh + + +def lat_long_distances(x: torch.Tensor, y: torch.Tensor, vect_x: torch.Tensor, vect_y: torch.Tensor, vect_h: torch.Tensor): + sin_lane_h = torch.sin(vect_h) + cos_lane_h = torch.cos(vect_h) + d_long = (x - vect_x) * cos_lane_h + (y - vect_y) * sin_lane_h + d_lat = -(x - vect_x) * sin_lane_h + (y - vect_y) * cos_lane_h + return d_lat, d_long + + +def lane_frenet_features(ego_state: np.ndarray, lane_states: np.ndarray): + """Taking the equation from the "Line defined by two points" section of + https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line + as well as this answer: https://stackoverflow.com/a/6853926 + """ + x1s = lane_states[:-1, 0] + y1s = lane_states[:-1, 1] + h1s = lane_states[:-1, 2] + + x2s = lane_states[1:, 0] + y2s = lane_states[1:, 1] + h2s = lane_states[1:, 2] + + A = ego_state[0] - x1s + B = ego_state[1] - y1s + + C = x2s - x1s + D = y2s - y1s + + dot = A * C + B * D + len_sq = C * C + D * D + params = np.ma.masked_invalid(np.divide(dot, len_sq, out=np.full_like(dot, np.nan), where=np.abs(len_sq) >= 1e-3)) + + if (params < 0).all(): + seg_idx = np.argmax(params) + lane_x = x1s[seg_idx] + lane_y = y1s[seg_idx] + lane_h = h1s[seg_idx] + elif (params > 1).all(): + seg_idx = np.argmin(params) + lane_x = x2s[seg_idx] + lane_y = y2s[seg_idx] + lane_h = h2s[seg_idx] + else: + seg_idx = np.argmin(np.abs(params)) + lane_x = x1s[seg_idx] + params[seg_idx] * C[seg_idx] + lane_y = y1s[seg_idx] + params[seg_idx] * D[seg_idx] + lane_h = h1s[seg_idx] + params[seg_idx] * (h2s[seg_idx] - h1s[seg_idx]) + + # plot_lane_frenet(lane_states, ego_state, np.array([xx, yy, hh]), seg_idx) + return lane_x, lane_y, lane_h, seg_idx + + +def np_rbf(input: np.ndarray, center: Union[np.ndarray, float] = 0.0, scale: Union[np.ndarray, float] = 1.0): + """Assuming here that input is of shape (..., D), with center and scale of broadcastable shapes. + """ + return np.exp(-0.5*np.square(input - center).sum(-1)/scale) + + +def pt_rbf(input: torch.Tensor, center: Union[torch.Tensor, float] = 0.0, scale: Union[torch.Tensor, float] = 1.0): + """Assuming here that input is of shape (..., D), with center and scale of broadcastable shapes. + """ + return torch.exp(-0.5*torch.square(input - center).sum(-1)/scale) + + +def wrap(angles: Union[torch.Tensor, np.ndarray]): + return (angles + np.pi) % (2 * np.pi) - np.pi + +def angle_wrap(angles: Union[torch.Tensor, np.ndarray]): + return wrap(angles) + + +# Custom torch functions that support jit.trace +class _tracable_exp_fn(torch.autograd.Function): + def forward(ctx, x: torch.Tensor): + ctx.save_for_backward(x) + return x.exp() + + def backward(ctx, dl_dx): + x, = ctx.saved_tensors + return _tracable_exp_fn.apply(x) * dl_dx +tracable_exp = _tracable_exp_fn.apply + +tracable_sqrt = lambda x: torch.pow(x, 0.5) + +tracable_norm = lambda x, dim: tracable_sqrt(x.square().sum(dim=dim)) + +def tracable_rbf(input: torch.Tensor, center: Union[torch.Tensor, float] = 0.0, scale: Union[torch.Tensor, float] = 1.0): + """Assuming here that input is of shape (..., D), with center and scale of broadcastable shapes. + """ + return tracable_exp(-0.5*tracable_sqrt(input - center).sum(-1)/scale) + + + +def ensure_length_nd(x, u, extra_info: Optional[Dict[str, torch.Tensor]] = None): + if extra_info is not None: + ep_lens = extra_info['ep_lengths'] + + x_reshaped = x[..., :ep_lens+1, :] + u_reshaped = u[..., :ep_lens+1, :] + # Again, this is one more timesteps than there should be for u, + # the last is all zero, and is ignored in the creation of B later. + + return x_reshaped, u_reshaped + else: + return x, u + + +def prediction_diversity(y_dists): + # Variance of predictions at the final time step + mus = y_dists.mus # 1, b, T, N, 2 + log_pis = y_dists.log_pis # 1, b, T, N + + xy = mus.squeeze(0)[:, -1] # b, N, 2 + probs = torch.exp(log_pis.squeeze(0)[:, -1]) # b, N + # weighted mean final xy + xy_mean = (xy * probs.unsqueeze(-1)).sum(dim=1) # b, 2 + dist_from_mean = torch.linalg.norm(xy_mean.unsqueeze(1) - xy, dim=-1) # b, N + # variance of eucledian distances from the (weighted) mean prediction + # TODO this metric doesnt have a proper statistical interpretation + # To make it interpretable follow something like + # similarity measure in DPP https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9387598 + # diversity energy https://openaccess.thecvf.com/content/ICCV2021/papers/Cui_LookOut_Diverse_Multi-Future_Prediction_and_Planning_for_Self-Driving_ICCV_2021_paper.pdf + var = (probs * torch.square(dist_from_mean)).sum(dim=-1) + + return var + + +def convert_state_pred2plan(x_pred: Union[torch.Tensor, np.ndarray]): + """ Transform + + Prediction input: 'x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta' + Planning input: x, y, theta, v + """ + if isinstance(x_pred, torch.Tensor): + x_plan = torch.stack([ + x_pred[..., 0], # x + x_pred[..., 1], # y + torch.atan2(x_pred[..., 6], x_pred[..., 7]), # theta + torch.linalg.norm(x_pred[..., 2:4], dim=-1), # v + ], dim=-1) + else: + x_plan = np.stack([ + x_pred[..., 0], # x + x_pred[..., 1], # y + np.arctan2(x_pred[..., 6], x_pred[..., 7]), # theta + np.linalg.norm(x_pred[..., 2:4], axis=-1), # v + ], axis=-1) + return x_plan + + +def gmms_from_single_futures(x_xyhv: torch.Tensor, dt: float): + assert x_xyhv.ndim == 4 # N, b, T, 4 + assert x_xyhv.shape[-1] == 4 # xyhv + ph = x_xyhv.shape[-2] + + mus = x_xyhv.unsqueeze(3) # (N, b, T, 1, 4) + log_pis = torch.zeros(mus.shape[:-1], dtype=mus.dtype, device=mus.device) + log_sigmas = torch.log((torch.arange(1, ph+1, dtype=mus.dtype, device=mus.device) * dt)**2*2) + log_sigmas = log_sigmas.reshape(1, 1, ph, 1, 1).repeat((x_xyhv.shape[0], x_xyhv.shape[1], 1, 1, 2)) + corrs = 0. * torch.ones(mus.shape[:-1], dtype=mus.dtype, device=mus.device) # TODO not sure what is reasonable + + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + return y_dists + + +def gmm_concat_as_modes(gmms: Iterable[GMM2D], probs: Iterable[float]): + gmm_joint = GMM2D( + log_pis=torch.concat([gmm.log_pis + np.log(p) for gmm, p in zip(gmms, probs)], dim=-1), + mus=torch.concat([gmm.mus for gmm in gmms], dim=-2), + log_sigmas=torch.concat([gmm.sigmas for gmm in gmms], dim=-2), + corrs=torch.concat([gmm.corrs for gmm in gmms], dim=-1), + ) + return gmm_joint + + +def gmm_concat_as_agents(gmms: Iterable[GMM2D]): + gmm_joint = GMM2D( + log_pis=torch.concat([gmm.log_pis for gmm in gmms], dim=0), + mus=torch.concat([gmm.mus for gmm in gmms], dim=0), + log_sigmas=torch.concat([gmm.log_sigmas for gmm in gmms], dim=0), + corrs=torch.concat([gmm.corrs for gmm in gmms], dim=0), + ) + return gmm_joint + + +def gmm_extend(gmm: GMM2D, num_modes: int): + gmm = GMM2D( + log_pis=torch.nn.functional.pad(gmm.log_pis, (0, num_modes), mode="constant", value=-np.inf), + mus=torch.nn.functional.pad(gmm.mus, (0, 0, 0, num_modes), mode="constant", value=np.nan), + log_sigmas=torch.nn.functional.pad(gmm.log_sigmas, (0, 0, 0, num_modes), mode="constant", value=np.nan), + corrs=torch.nn.functional.pad(gmm.corrs, (0, num_modes), mode="constant", value=0.), + ) + return gmm + + +def move_list_element_to_front(a: list, i: int) -> list: + a = [a[i]] + [a[j] for j in range(len(a)) if j != i] + return a + +def soft_min(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((y-x)/2) + exp2 = expfun((x-y)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_max(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((x-y)/2) + exp2 = expfun((y-x)/2) + return (exp1*x+exp2*y)/(exp1+exp2) +def soft_sat(x,x_min=None,x_max=None,gamma=5): + if x_min is None and x_max is None: + return x + elif x_min is None and x_max is not None: + return soft_min(x,x_max,gamma) + elif x_min is not None and x_max is None: + return soft_max(x,x_min,gamma) + else: + if isinstance(x_min,torch.Tensor) or isinstance(x_min,np.ndarray): + assert (x_max>x_min).all() + else: + assert x_max>x_min + xc = x - (x_min+x_max)/2 + if isinstance(x,torch.Tensor): + return xc/(torch.pow(1+torch.pow(torch.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + elif isinstance(x,np.ndarray): + return xc/(np.power(1+np.power(np.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + else: + raise Exception("data type not supported") + + +def recursive_dict_list_tuple_apply(x, type_func_dict, ignore_if_unspecified=False): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert list not in type_func_dict + assert tuple not in type_func_dict + assert dict not in type_func_dict + assert torch.nn.ParameterDict not in type_func_dict + assert torch.nn.ParameterList not in type_func_dict + + if isinstance(x, (dict, OrderedDict, torch.nn.ParameterDict)): + new_x = ( + OrderedDict() + if isinstance(x, OrderedDict) + else dict() + ) + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) + return new_x + elif isinstance(x, (list, tuple, torch.nn.ParameterList)): + ret = [recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) for v in x] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + if ignore_if_unspecified: + return x + else: + raise NotImplementedError("Cannot handle data type %s" % str(type(x))) + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert begin_axis < end_axis + assert begin_axis >= 0 + assert end_axis <= len(x.shape) + assert isinstance(target_dims, (tuple, list)) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i >= end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def yaw_from_pos(pos: torch.Tensor, dt, yaw_correction_speed=0.): + """ + Compute yaws from position sequences. Optionally suppress yaws computed from low-velocity steps + + Args: + pos (torch.Tensor): sequence of positions [..., T, 2] + dt (float): delta timestep to compute speed + yaw_correction_speed (float): zero out yaw change when the speed is below this threshold (noisy heading) + + Returns: + accum_yaw (torch.Tensor): sequence of yaws [..., T-1, 1] + """ + + pos_diff = pos[..., 1:, :] - pos[..., :-1, :] + yaw = torch.atan2(pos_diff[..., 1], pos_diff[..., 0]) + delta_yaw = torch.cat((yaw[..., [0]], yaw[..., 1:] - yaw[..., :-1]), dim=-1) + speed = torch.norm(pos_diff, dim=-1) / dt + delta_yaw[speed < yaw_correction_speed] = 0. + accum_yaw = torch.cumsum(delta_yaw, dim=-1) + return accum_yaw[..., None] + + +def all_gather(data): + """Run all_gather on arbitrary picklable data (not necessarily tensors) + + Parameters + ---------- + data: any picklable object + + Returns + -------- + list[data] + List of data gathered from each rank + """ + world_size = torch.distributed.get_world_size() + + if world_size == 1: + return [data] + + # Serialize to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # Obtain Tensor size of each rank + local_size = torch.IntTensor([tensor.numel()]).to("cuda") + size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] + torch.distributed.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # Receive Tensor from all ranks + # We pad the tensor because torch all_gather does not support gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size, )).to("cuda")) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda") + tensor = torch.cat((tensor, padding), dim=0) + torch.distributed.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +# Wrapper around dict to identify batchable dict data. +class batchable_dict(dict): + pass + +class batchable_list(list): + pass + +class batchable_nonuniform_tensor(torch.Tensor): + pass + + +def restore(data): + """ + In case we dilled some structures to share between multiple process this function will restore them. + If the data input are not bytes we assume it was not dilled in the first place + + :param data: Possibly dilled data structure + :return: Un-dilled data structure + """ + if type(data) is bytes: + return dill.loads(data) + return data + + +def collate(batch): + if len(batch) == 0: + return batch + elem = batch[0] + if elem is None: + return None + elif isinstance(elem, str) or elem.__class__.__name__ == "batchable_list" or elem.__class__.__name__ == "batchable_nonuniform_tensor": + # TODO isinstance(elem, batchable_nonuniform_tensor) is never true, perhaps some import path comparison issue + return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch + elif isinstance(elem, container_abcs.Sequence): + if len(elem) == 4: # We assume those are the maps, map points, headings and patch_size + scene_map, scene_pts, heading_angle, patch_size = zip(*batch) + if heading_angle[0] is None: + heading_angle = None + else: + heading_angle = torch.Tensor(heading_angle) + map = scene_map[0].get_cropped_maps_from_scene_map_batch(scene_map, + scene_pts=torch.Tensor(scene_pts), + patch_size=patch_size[0], + rotation=heading_angle) + return map + transposed = zip(*batch) + return [collate(samples) for samples in transposed] + elif elem.__class__.__name__ == "batchable_dict": + # We dill the dictionary for the same reason as the neighbors structure (see below). + # Unlike for neighbors where we keep a list, here we collate elements recursively + data_dict = {key: collate([d[key] for d in batch]) for key in elem} + return dill.dumps(data_dict) if torch.utils.data.get_worker_info() else data_dict + elif isinstance(elem, container_abcs.Mapping): + # We have to dill the neighbors structures. Otherwise each tensor is put into + # shared memory separately -> slow, file pointer overhead + # we only do this in multiprocessing + neighbor_dict = {key: [d[key] for d in batch] for key in elem} + return dill.dumps(neighbor_dict) if torch.utils.data.get_worker_info() else neighbor_dict + try: + return default_collate(batch) + except RuntimeError: + # This happens when tensors are not of the same shape. + return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch + + +def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0): + return v.reshape( + torch.Size([1] * n_dim_to_prepend) + + v.shape + + torch.Size([1] * n_dim_to_append)) + + +def block_diag(m): + """ + Make a block diagonal matrix along dim=-3 + EXAMPLE: + block_diag(torch.ones(4,3,2)) + should give a 12 x 8 matrix with blocks of 3 x 2 ones. + Prepend batch dimensions if needed. + You can also give a list of matrices. + :type m: torch.Tensor, list + :rtype: torch.Tensor + """ + if type(m) is list: + m = torch.cat([m1.unsqueeze(-3) for m1 in m], -3) + + d = m.dim() + n = m.shape[-3] + siz0 = m.shape[:-3] + siz1 = m.shape[-2:] + m2 = m.unsqueeze(-2) + eye = attach_dim(torch.eye(n, device=m.device).unsqueeze(-2), d - 3, 1) + return (m2 * eye).reshape(siz0 + torch.Size(torch.tensor(siz1) * n)) + +def removeprefix(line, prefix): + if line.startswith(prefix): + line_new = line[len(prefix):] + return line_new \ No newline at end of file diff --git a/diffstack/utils/vis_utils.py b/diffstack/utils/vis_utils.py new file mode 100644 index 0000000..b448b25 --- /dev/null +++ b/diffstack/utils/vis_utils.py @@ -0,0 +1,394 @@ +import numpy as np +from PIL import Image, ImageDraw +from collections import defaultdict +from typing import List, Optional, Tuple, Dict +from trajdata.maps.vec_map import VectorMap + +from l5kit.geometry import transform_points +from l5kit.rasterization.render_context import RenderContext +from l5kit.configs.config import load_metadata +from trajdata.maps import RasterizedMap + +from diffstack.utils.tensor_utils import map_ndarray +from diffstack.utils.geometry_utils import get_box_world_coords_np +import diffstack.utils.tensor_utils as TensorUtils +import os +import glob +from bokeh.models import ColumnDataSource, GlyphRenderer +from bokeh.plotting import figure, curdoc +import bokeh +from bokeh.io import export_png +from trajdata.utils.arr_utils import ( + transform_coords_2d_np, + batch_nd_transform_points_pt, + batch_nd_transform_points_np, +) +from trajdata.utils.vis_utils import draw_map_elems + +COLORS = { + "agent_contour": "#247BA0", + "agent_fill": "#56B1D8", + "ego_contour": "#911A12", + "ego_fill": "#FE5F55", +} + + +def agent_to_raster_np(pt_tensor, trans_mat): + pos_raster = transform_points(pt_tensor[None], trans_mat)[0] + return pos_raster + + +def draw_actions( + state_image, + trans_mat, + pred_action=None, + pred_plan=None, + pred_plan_info=None, + ego_action_samples=None, + plan_samples=None, + action_marker_size=3, + plan_marker_size=8, +): + im = Image.fromarray((state_image * 255).astype(np.uint8)) + draw = ImageDraw.Draw(im) + + if pred_action is not None: + raster_traj = agent_to_raster_np( + pred_action["positions"].reshape(-1, 2), trans_mat + ) + for point in raster_traj: + circle = np.hstack([point - action_marker_size, point + action_marker_size]) + draw.ellipse(circle.tolist(), fill="#FE5F55", outline="#911A12") + if ego_action_samples is not None: + raster_traj = agent_to_raster_np( + ego_action_samples["positions"].reshape(-1, 2), trans_mat + ) + for point in raster_traj: + circle = np.hstack([point - action_marker_size, point + action_marker_size]) + draw.ellipse(circle.tolist(), fill="#808080", outline="#911A12") + + if pred_plan is not None: + pos_raster = agent_to_raster_np(pred_plan["positions"][:, -1], trans_mat) + for pos in pos_raster: + circle = np.hstack([pos - plan_marker_size, pos + plan_marker_size]) + draw.ellipse(circle.tolist(), fill="#FF6B35") + + if plan_samples is not None: + pos_raster = agent_to_raster_np(plan_samples["positions"][0, :, -1], trans_mat) + for pos in pos_raster: + circle = np.hstack([pos - plan_marker_size, pos + plan_marker_size]) + draw.ellipse(circle.tolist(), fill="#FF6B35") + + im = np.asarray(im) + # visualize plan heat map + if pred_plan_info is not None and "location_map" in pred_plan_info: + import matplotlib.pyplot as plt + + cm = plt.get_cmap("jet") + heatmap = pred_plan_info["location_map"][0] + heatmap = heatmap - heatmap.min() + heatmap = heatmap / heatmap.max() + heatmap = cm(heatmap) + + heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)) + heatmap = heatmap.resize(size=(im.shape[1], im.shape[0])) + heatmap = np.asarray(heatmap)[..., :3] + padding = np.ones((im.shape[0], 200, 3), dtype=np.uint8) * 255 + + composite = heatmap.astype(np.float32) * 0.3 + im.astype(np.float32) * 0.7 + composite = composite.astype(np.uint8) + im = np.concatenate((im, padding, heatmap, padding, composite), axis=1) + + return im + + +def draw_agent_boxes( + image, pos, yaw, extent, raster_from_agent, outline_color, fill_color +): + boxes = get_box_world_coords_np(pos, yaw, extent) + boxes_raster = transform_points(boxes, raster_from_agent) + boxes_raster = boxes_raster.reshape((-1, 4, 2)).astype(np.int) + + im = Image.fromarray((image * 255).astype(np.uint8)) + im_draw = ImageDraw.Draw(im) + for b in boxes_raster: + im_draw.polygon( + xy=b.reshape(-1).tolist(), outline=outline_color, fill=fill_color + ) + + im = np.asarray(im).astype(np.float32) / 255.0 + return im + + +def render_state_trajdata( + batch: dict, + batch_idx: int, + action, +) -> np.ndarray: + pos = batch["hist_pos"][batch_idx, -1] + yaw = batch["hist_yaw"][batch_idx, -1] + extent = batch["extent"][batch_idx, :2] + + image = RasterizedMap.to_img( + TensorUtils.to_tensor(batch["maps"][batch_idx]), + [[0, 1, 2], [3, 4], [5, 6]], + ) + + image = draw_agent_boxes( + image, + pos=pos[None, :], + yaw=yaw[None, :], + extent=extent[None, :], + raster_from_agent=batch["raster_from_agent"][batch_idx], + outline_color=COLORS["ego_contour"], + fill_color=COLORS["ego_fill"], + ) + + scene_index = batch["scene_index"][batch_idx] + agent_scene_index = scene_index == batch["scene_index"] + agent_scene_index[batch_idx] = 0 # don't plot ego + + neigh_pos = batch["centroid"][agent_scene_index] + neigh_yaw = batch["world_yaw"][agent_scene_index] + neigh_extent = batch["extent"][agent_scene_index, :2] + + if neigh_pos.shape[0] > 0: + image = draw_agent_boxes( + image, + pos=neigh_pos, + yaw=neigh_yaw[:, None], + extent=neigh_extent, + raster_from_agent=batch["raster_from_world"][batch_idx], + outline_color=COLORS["agent_contour"], + fill_color=COLORS["agent_fill"], + ) + + plan_info = None + plan_samples = None + action_samples = None + if "plan_info" in action.agents_info: + plan_info = TensorUtils.map_ndarray( + action.agents_info["plan_info"], lambda x: x[[batch_idx]] + ) + if "plan_samples" in action.agents_info: + plan_samples = TensorUtils.map_ndarray( + action.agents_info["plan_samples"], lambda x: x[[batch_idx]] + ) + if "action_samples" in action.agents_info: + action_samples = TensorUtils.map_ndarray( + action.agents_info["action_samples"], lambda x: x[[batch_idx]] + ) + + vis_action = TensorUtils.map_ndarray( + action.agents.to_dict(), lambda x: x[batch_idx] + ) + image = draw_actions( + image, + trans_mat=batch["raster_from_agent"][batch_idx], + pred_action=vis_action, + pred_plan_info=plan_info, + ego_action_samples=action_samples, + plan_samples=plan_samples, + action_marker_size=2, + plan_marker_size=3, + ) + return image + + +def get_state_image_with_boxes_l5kit(ego_obs, agents_obs, rasterizer): + yaw = ego_obs["world_yaw"] # set to 0 to fix the video + state_im = rasterizer.rasterize(ego_obs["centroid"], yaw) + + raster_from_world = rasterizer.render_context.raster_from_world( + ego_obs["centroid"], yaw + ) + raster_from_agent = raster_from_world @ ego_obs["world_from_agent"] + + state_im = draw_agent_boxes( + state_im, + agents_obs["centroid"], + agents_obs["world_yaw"][:, None], + agents_obs["extent"][:, :2], + raster_from_world, + outline_color=COLORS["agent_contour"], + fill_color=COLORS["agent_fill"], + ) + + state_im = draw_agent_boxes( + state_im, + ego_obs["centroid"][None], + ego_obs["world_yaw"][None, None], + ego_obs["extent"][None, :2], + raster_from_world, + outline_color=COLORS["ego_contour"], + fill_color=COLORS["ego_fill"], + ) + + return state_im, raster_from_agent, raster_from_world + + +def get_agent_edge(xy, h, extent): + edges = ( + np.array([[0.5, 0.5], [0.5, -0.5], [-0.5, -0.5], [-0.5, 0.5]]) + * extent[np.newaxis, :2] + ) + rotM = np.array([[np.cos(h), -np.sin(h)], [np.sin(h), np.cos(h)]]) + edges = (rotM @ edges[..., np.newaxis]).squeeze(-1) + xy[np.newaxis, :] + return edges + + +def plot_scene_open_loop( + fig: figure, + traj: np.ndarray, + extent: np.ndarray, + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + bbox: Optional[Tuple[float, float, float, float]] = None, + color_scheme="blue_red", + mask=None, +): + Na = traj.shape[0] + static_glyphs = draw_map_elems(fig, vec_map, map_from_world_tf, bbox) + agent_edge = np.stack( + [ + get_agent_edge(traj[i, 0, :2], traj[i, 0, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + agent_patches = defaultdict(lambda: None) + traj_lines = defaultdict(lambda: None) + traj_xy = batch_nd_transform_points_np(traj[..., :2], map_from_world_tf[None]) + + if color_scheme == "blue_red": + agent_color = ["red"] + ["blue"] * (Na - 1) + elif color_scheme == "palette": + palette = bokeh.palettes.Category20[20] + agent_color = ["blueviolet"] + [palette[i % 20] for i in range(Na - 1)] + for i in range(Na): + if mask is None or mask[i]: + agent_patches[i] = fig.patch( + x=agent_edge[i, :, 0], + y=agent_edge[i, :, 1], + color=agent_color[i], + ) + traj_lines[i] = fig.line( + x=traj_xy[i, :, 0], + y=traj_xy[i, :, 1], + color=agent_color[i], + line_width=2, + ) + return static_glyphs, agent_patches, traj_lines + + +def delete_files_in_directory(directory_path): + try: + files = os.listdir(directory_path) + for file in files: + file_path = os.path.join(directory_path, file) + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + print("Error occurred while deleting files.") + + +def make_gif(frame_folder, gif_name, duration=100, loop=0): + frames = [ + Image.open(image) + for image in sorted(glob.glob(f"{frame_folder}/*.png"), key=os.path.getmtime) + ] + frame_one = frames[0] + frame_one.save( + gif_name, + format="GIF", + append_images=frames, + save_all=True, + duration=duration, + loop=loop, + ) + + +def animate_scene_open_loop( + fig: figure, + traj: np.ndarray, + extent: np.ndarray, + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + # rel_bbox: Optional[Tuple[float, float, float, float]] = None, + bbox: Optional[Tuple[float, float, float, float]] = None, + color_scheme="blue_red", + mask=None, + dt=0.1, + tmp_dir="tmp", + gif_name="diffstack_anim.gif", +): + Na, T = traj.shape[:2] + # traj_xy = batch_nd_transform_points_np(traj[..., :2], map_from_world_tf[None]) + # bbox = ( + # [ + # rel_bbox[0] + traj_xy[0, 0, 0], + # rel_bbox[1] + traj_xy[0, 0, 0], + # rel_bbox[2] + traj_xy[0, 0, 1], + # rel_bbox[3] + traj_xy[0, 0, 1], + # ] + # if rel_bbox is not None + # else None + # ) + static_glyphs = draw_map_elems(fig, vec_map, map_from_world_tf, bbox) + agent_edge = np.stack( + [ + get_agent_edge(traj[i, 0, :2], traj[i, 0, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + agent_patches = defaultdict(lambda: None) + traj_lines = defaultdict(lambda: None) + + agent_xy_source = defaultdict(lambda: None) + if color_scheme == "blue_red": + agent_color = ["red"] + ["blue"] * (Na - 1) + elif color_scheme == "palette": + palette = bokeh.palettes.Category20[20] + agent_color = ["blueviolet"] + [palette[i % 20] for i in range(Na - 1)] + for i in range(Na): + if mask is None or mask[i]: + agent_xy_source[i] = ColumnDataSource( + data=dict(x=agent_edge[i, :, 0], y=agent_edge[i, :, 1]) + ) + + agent_patches[i] = fig.patch( + x="x", + y="y", + color=agent_color[i], + source=agent_xy_source[i], + name=f"patch_{i}", + ) + if os.path.exists(tmp_dir): + if os.path.isfile(tmp_dir): + os.remove(tmp_dir) + else: + os.mkdir(tmp_dir) + delete_files_in_directory(tmp_dir) + + for t in range(T): + agent_edge = np.stack( + [ + get_agent_edge(traj[i, t, :2], traj[i, t, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + for i in range(Na): + if mask is None or mask[i]: + new_source_data = dict(x=agent_edge[i, :, 0], y=agent_edge[i, :, 1]) + patch = fig.select_one({"name": f"patch_{i}"}) + patch.data_source.data = new_source_data + export_png(fig, filename=tmp_dir + "/plot_" + str(t) + ".png") + + make_gif(tmp_dir, gif_name, duration=dt * 1000) + delete_files_in_directory(tmp_dir) + return static_glyphs, agent_patches, traj_lines diff --git a/diffstack/utils/visualization.py b/diffstack/utils/visualization.py new file mode 100644 index 0000000..2946890 --- /dev/null +++ b/diffstack/utils/visualization.py @@ -0,0 +1,915 @@ +import numpy as np +import torch + +from collections import defaultdict, OrderedDict +from typing import Dict, Optional, Union, Any, Tuple + +from nuscenes.map_expansion.map_api import NuScenesMap + +from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.batch import AgentBatch, SceneBatch +from trajdata.maps import RasterizedMap + +from diffstack.utils.utils import subsample_future +from diffstack.utils.metrics import compute_ade_pt + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.animation import FuncAnimation +from matplotlib.axes import Axes + + +def legend_unique_labels(ax, **kwargs): + handles, labels = ax.get_legend_handles_labels() + labels, legend_ids = np.unique(labels, return_index=True) + handles = [handles[i] for i in legend_ids] + plt.legend(handles, labels, **kwargs) + + +def plot_plan_input_batch( + batch: Union[AgentBatch, SceneBatch], + batch_idx: int, + ax: Optional[Axes] = None, + legend: bool = True, + show: bool = True, + close: bool = True, +) -> None: + if ax is None: + _, ax = plt.subplots() + + # For now we just convert SceneBatch to AgentBatch + if isinstance(batch, SceneBatch): + batch = batch.to_agent_batch(batch.extras["pred_agent_ind"]) + + agent_name: str = batch.agent_name[batch_idx] + agent_type: AgentType = AgentType(batch.agent_type[batch_idx].item()) + ax.set_title(f"{str(agent_type)}/{agent_name}") + + pred_agent_history_xy: torch.Tensor = batch.agent_hist[batch_idx].cpu() + pred_agent_future_xy: torch.Tensor = batch.agent_fut[batch_idx, :, :2].cpu() + neighbor_hist = batch.neigh_hist[batch_idx].cpu() + neighbor_fut = batch.neigh_fut[batch_idx].cpu() + # The index of the current time step depends on the padding direction when the history is incomplete. + if batch.history_pad_dir == batch.history_pad_dir.AFTER: + pred_agent_cur_ind = batch.agent_hist_len[batch_idx].cpu() - 1 + neighbor_cur_ind = batch.neigh_hist_len[batch_idx].cpu() - 1 + else: + pred_agent_cur_ind = -1 + neighbor_cur_ind = [-1 for _ in range(neighbor_hist.shape[0])] + + robot_ind: torch.Tensor = batch.extras['robot_ind'][batch_idx].cpu() + lane_projection_points: torch.Tensor = batch.extras['lane_projection_points'][batch_idx].cpu() + goal: torch.Tensor = batch.extras['goal'][batch_idx].cpu() + + # Map + if batch.maps is not None: + agent_from_world_tf: torch.Tensor = batch.agents_from_world_tf[batch_idx].cpu() + world_from_raster_tf: torch.Tensor = torch.linalg.inv( + batch.rasters_from_world_tf[batch_idx].cpu() + ) + + agent_from_raster_tf: torch.Tensor = agent_from_world_tf @ world_from_raster_tf + + patch_size: int = batch.maps[batch_idx].shape[-1] + + left_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ + 0 + ].item() + right_extent: float = ( + agent_from_raster_tf @ torch.tensor([patch_size, 0.0, 1.0]) + )[0].item() + bottom_extent: float = ( + agent_from_raster_tf @ torch.tensor([0.0, patch_size, 1.0]) + )[1].item() + top_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ + 1 + ].item() + + ax.imshow( + RasterizedMap.to_img( + batch.maps[batch_idx].cpu(), + # [[0], [1], [2]] + # [[0, 1, 2], [3, 4], [5, 6]], + ), + extent=( + left_extent, + right_extent, + bottom_extent, + top_extent, + ), + alpha=0.3, + ) + + # Lanes + if robot_ind >= 0: + ax.scatter( + lane_projection_points[:, 0], + lane_projection_points[:, 1], + s=15, + c="black", + label="Lane projections", + ) + if 'lanes_near_goal' in batch.extras: + lanes_near_goal = batch.extras['lanes_near_goal'][batch_idx] + ax.plot([], [], c="grey", ls="--", label="Goal lanes") + for lane_near_goal in lanes_near_goal: + ax.plot(lane_near_goal[:, 0], lane_near_goal[:, 1], c="grey", ls="--") + + # Pred agent + ax.plot(pred_agent_history_xy[..., 0], pred_agent_history_xy[..., 1], c="orange", ls="--", label="Agent History") + ax.quiver( + pred_agent_history_xy[..., 0], + pred_agent_history_xy[..., 1], + pred_agent_history_xy[..., -1], + pred_agent_history_xy[..., -2], + color="k", + # scale=50, + width=2e-3, + ) + ax.plot(pred_agent_future_xy[..., 0], pred_agent_future_xy[..., 1], c="violet", label="Agent Future") + ax.scatter(pred_agent_history_xy[pred_agent_cur_ind, 0], pred_agent_history_xy[pred_agent_cur_ind, 1], s=20, c="orangered", label="Agent Current") + + # Ego + goal + if robot_ind >= 0: + ego_hist = neighbor_hist[robot_ind] + ego_fut = neighbor_fut[robot_ind] + ax.plot(ego_hist[:, 0], ego_hist[:, 1], c="olive", ls="--", label="Ego History") + ax.plot(ego_fut[:, 0], ego_fut[:, 1], c="darkgreen", label="Ego Future") + ax.scatter( + ego_hist[None, neighbor_cur_ind[robot_ind], 0], + ego_hist[None, neighbor_cur_ind[robot_ind], 1], + s=20, + c="red", + label="Ego Current", + ) + ax.scatter(goal[None, 0], goal[None, 1], s=15, c="purple", label="Goal") + + # Neighbors + neighbors_idx = [i for i in range(batch.num_neigh[batch_idx]) if i != robot_ind] + if len(neighbors_idx) > 0: + ax.plot([], [], c="olive", ls="--", label="Neighbor History") + for n in neighbors_idx: + ax.plot(neighbor_hist[n][:, 0], neighbor_hist[n, :, 1], c="olive", ls="--") + + ax.plot([], [], c="darkgreen", label="Neighbor Future") + for n in neighbors_idx: + ax.plot(neighbor_fut[n][:, 0], neighbor_fut[n, :, 1], c="darkgreen") + + ax.scatter( + torch.stack([neighbor_hist[n][neighbor_cur_ind[n], 0] for n in neighbors_idx], dim=0), + torch.stack([neighbor_hist[n][neighbor_cur_ind[n], 1] for n in neighbors_idx], dim=0), + s=20, + c="gold", + label="Neighbor Current", + ) + + # Ego conditioning in prediction + if batch.robot_fut is not None and batch.robot_fut.shape[1] > 0: + raise NotImplementedError() + + # Formatting + ax.set_xlabel("x (m)") + ax.set_ylabel("y (m)") + ax.grid(False) + ax.axis("equal") + + if legend: + ax.legend(loc="best", frameon=True) + + if show: + plt.show() + + if close: + plt.close() + + return ax + + + +def plot_plan_result( + plan_x: np.ndarray, + control_x: np.ndarray = None, + ax: Optional[Axes] = None, +) -> None: + if ax is None: + _, ax = plt.subplots() + ax.plot(plan_x[..., 0], plan_x[..., 1], c="blue", ls="-", label="Plan") + if control_x is not None: + ax.plot(control_x[..., 0], control_x[..., 1], c="purple", ls="-", label="Control") + return ax + + +def plot_plan_candidates( + plan_candidates_x: np.ndarray, + ax: Optional[Axes] = None, +) -> None: + if ax is None: + _, ax = plt.subplots() + ax.plot(plan_candidates_x[..., 0].transpose(1, 0), plan_candidates_x[..., 1].transpose(1, 0), c="lightsteelblue", ls="-", label="Plan candidates") + + return ax + + + + +def plot_map(dataroot: Optional[str], map_name: Optional[str], map_patch: Tuple[float, float, float, float], nusc_map: Optional[NuScenesMap], figsize=(24, 24)) -> Tuple[Figure, plt.Axes]: + if nusc_map is None: + assert dataroot is not None and map_name is not None, "Must provide map_obj or path to map file." + nusc_map = NuScenesMap(dataroot, map_name) + bitmap = None #BitMap(dataroot, map_name, 'basemap') + + return nusc_map.render_map_patch(map_patch, + ['drivable_area', + 'road_segment', + # 'road_block', + 'lane', + 'ped_crossing', + 'walkway', + 'stop_line', + 'carpark_area', + 'road_divider', + 'lane_divider'], + alpha=0.05, + render_egoposes_range=False, + bitmap=bitmap, + figsize=figsize, + render_legend=False, + ) + + +def visualize_plan_batch(nusc_maps, scenes, x_t, y_t, plan_data, plot_data, titles, plot_styles, num_plots, planner, ph, planh): + + allowed_plot_styles = ["anim_iters", "compare_futures", "fan_with_pred", "compare_fan_vs_mpc"] + if not all([x in allowed_plot_styles for x in plot_styles]): + raise ValueError(f"Not all requested plot styles are known.\n Requested: {plot_styles}.\n Known: {allowed_plot_styles}") + + plot_margin_m = 60 + plan_color = 'coral' + ego_gt_color = 'royalblue' + gt_color = 'black' + pred_color = 'orange' + pred_gt_color = 'yellow' + plan_gt_color = 'red' + plan_nof_color = 'purple' + plan_nopred_color = 'grey' + plan_tree_color = 'brown' + ego_gthcost_color = 'green' + mpc_plan_color = 'green' + + plan_metrics, plan_iters = plot_data['plan'] + nopred_plan_metrics, nopred_plan_iters = plot_data['nopred_plan'] + # nof_plan_metrics, nof_plan_iters = plot_data['nof_plan'] + gt_plan_metrics, gt_plan_iters = plot_data['gt_plan'] + + output = defaultdict(list) + plotted_inds = [] + + def index_plan_iters(plan_iters, batch_i): + skip_list = ['x', 'u', 'cost', 'all_gt_neighbors', 'gt_neighbors', 'cost_components', 'hcost_components'] + res = {} + for k, v in plan_iters.items(): + if k in skip_list: + continue + if v is None: + res[k] = None + elif isinstance(v, list) or isinstance(v, tuple): + res[k] = v[batch_i].cpu().numpy() + elif v.ndim == 1: + res[k] = v[batch_i].cpu().numpy() + else: + res[k] = v[:, batch_i].cpu().numpy() + return res + + def apply_offset(xy, offset_x): + if offset_x is not None: + while offset_x.ndim < xy.ndim: + offset_x = offset_x[None] + xy = xy[..., :2] + offset_x[..., :2] + return xy + + def plot_traj(ax, xy, offset_x, label=None, c=None, plot_dot=True, linewidth=1.5, **kwargs): + xy = apply_offset(xy, offset_x) + if plot_dot: + ax.scatter(xy[[0], ..., 0], xy[[0], ..., 1], c=c, s=80.0) + ax_plots = ax.plot(xy[..., 0], xy[..., 1], label=label, c=c, linewidth=linewidth, **kwargs) + return ax_plots + + def plot_preds(ax, xy, offset_x, probs=None, label=None, c=None, linewidth=0.75, alphas=None, **kwargs): + xy = apply_offset(xy, offset_x) + if probs is None: + ax_plots = ax.plot(xy[..., 0], xy[..., 1], label=label, c=c, linewidth=linewidth, **kwargs) + else: + # Normalize probs to 0.2...1 + if alphas is None: + alphas = probs / probs.max() * 0.6 + 0.4 + # Unfortunately there is no support for multiple alpha values so we need to loop + ax_plots = [] + for i in range(xy.shape[1]): + ax_plots.extend(ax.plot(xy[:, i, 0], xy[:, i, 1], label=label, c=c, linewidth=linewidth, alpha=alphas[i], **kwargs)) + return ax_plots + + def plot_lane(ax, lanes, offset_x, label="lane", c='black', marker='x', s=4, **kwargs): + lanes = apply_offset(lanes, offset_x) + ax.scatter(lanes[..., 0], lanes[..., 1], label=label, c=c, marker=marker, s=s, linewidths=1, **kwargs) + + plan_batch_filter = plan_iters['plan_batch_filter'] + pred_mus_batch = plot_data['y_dists'].mus[0, plan_batch_filter].cpu().numpy() # (b, t, N, 2) + pred_probs_batch = torch.exp(plot_data['y_dists'].log_pis[0, plan_batch_filter, 0]).cpu().numpy() # (b, N) + gt_pred_xy_batch = y_t[plan_batch_filter].cpu().numpy() + gt_pred_target_xy_batch = plot_data['y_for_pred'][plan_batch_filter].cpu().numpy() # differs from g + x_t_batch = x_t[plan_batch_filter, -1:, :2].cpu().numpy() + pred_ade_unbiased = compute_ade_pt(plot_data['predictions'], y_t)[plan_batch_filter] + pred_ade_biased = compute_ade_pt(plot_data['predictions'], plot_data['y_for_pred'])[plan_batch_filter] + cost_components_batch = plan_iters['cost_components'].mean(0).cpu().numpy() + hcost_components_batch = plan_iters['hcost_components'].mean(0).cpu().numpy() + + assert len(gt_plan_metrics['hcost']) == len(plan_metrics['hcost']) + assert len(pred_ade_unbiased) == len(plan_metrics['hcost']) + assert cost_components_batch.shape[0] == gt_pred_xy_batch.shape[0] + assert hcost_components_batch.shape[0] == gt_pred_xy_batch.shape[0] + + for batch_i in range(gt_pred_xy_batch.shape[0]): + # Limit to number of requested plots + if num_plots > 0 and len(plotted_inds) >= num_plots: + break + + scene = scenes[batch_i] + offset_x = np.array([scene.x_min, scene.y_min, 0., 0.]) + plan_iters_i = index_plan_iters(plan_iters, batch_i) + gt_neighbors_xy = plan_iters['gt_neighbors'][:, :, batch_i].cpu().numpy() + + gt_pred_xy = gt_pred_xy_batch[batch_i] + gt_pred_target_xy = gt_pred_target_xy_batch[batch_i] + pred_mus = pred_mus_batch[batch_i] + lanes_xy = plan_iters_i['lanes'] + lane_points_xy = plan_iters_i['lane_points'] + gt_neighbors_xy = gt_neighbors_xy[np.logical_not(np.isnan(gt_neighbors_xy[:, 0, 0]))] + + pred_t_xy = x_t_batch[batch_i] + gt_pred_xy = np.concatenate([pred_t_xy, gt_pred_xy], axis=0) + pred_mus = np.concatenate([pred_t_xy[:, None].repeat(pred_mus.shape[1], axis=1), pred_mus], axis=0) + + if planner in ['mpc', 'fan_mpc']: + last_plan_x = plan_iters['x'][-1][:, batch_i].cpu().numpy() + last_nopred_plan_x = nopred_plan_iters['x'][-1][:, batch_i].cpu().numpy() + # last_nof_plan_x = nof_plan_iters['x'][-1][:, batch_i].cpu().numpy() + last_gt_plan_x = gt_plan_iters['x'][-1][:, batch_i].cpu().numpy() + plan_gt_xy = plan_iters_i['x_gt'][:, :2] + plan_gt_xy = np.concatenate([last_plan_x[:1, :2], plan_gt_xy], axis=0) # append t0 + plan_cost_i = plan_iters['cost'][-1][batch_i].cpu().numpy()/plan_gt_xy.shape[0] + # # skip boring examples + # if len(plan_iters_i['gt_neighbors']) < 1: + # continue + # is_interesting = False + # if (nof_plan_metrics['hcost'][batch_i] < gt_plan_metrics['hcost'][batch_i] - 0.01) and gt_plan_converged[batch_i]: + # is_interesting = True + # if nopred_plan_metrics['hcost'][batch_i] > gt_plan_metrics['hcost'][batch_i] + 0.1: + # is_interesting = True + # if not is_interesting: + # continue + + # add present + + # Filter + # TODO + plotted_inds.append(batch_i) + + label = ( + f"{titles[batch_i]} \n" + + f"#{batch_i} n={len(gt_neighbors_xy)} conv={str(plan_iters_i['converged'])} " + + f"pred_ade={pred_ade_biased[batch_i]:.3f} {pred_ade_unbiased[batch_i]:.3f} plan_mse={plan_metrics['mse'][batch_i]:.3f} plan_mcost={plan_cost_i:.3f} plan_hcost={plan_metrics['hcost'][batch_i]:.3f} \n" + + # f"nof_plan_hcost={nof_plan_metrics['hcost'][batch_i]:.3f} nopred_plan_hcost={nopred_plan_metrics['hcost'][batch_i]:.3f} gtplan_hcost={gt_plan_metrics['hcost'][batch_i]:.3f} \n" + + f"nopred_plan_hcost={nopred_plan_metrics['hcost'][batch_i]:.3f} gtplan_hcost={gt_plan_metrics['hcost'][batch_i]:.3f} \n" + + f"nopred_regret_now={nopred_plan_metrics['hcost'][batch_i]-gt_plan_metrics['hcost'][batch_i]:.4f} " + + f"nopred_regret_cache={plan_data['nopred_plan_hcost'][batch_i]-plan_data['gt_plan_hcost'][batch_i]:.4f}") + print (label) + + # # Debug cost change + # print (f"gt plan internal cost: {torch.stack([c[batch_i].cpu() for c in gt_plan_iters['cost']])}") + # print (f"nof plan internal cost: {torch.stack([c[batch_i].cpu() for c in nof_plan_iters['cost']])}") + # print (f"nopred plan internal cost: {torch.stack([c[batch_i].cpu() for c in nopred_plan_iters['cost']])}") + # print (f"pred plan internal cost: {torch.stack([c[batch_i].cpu() for c in plan_iters['cost']])}") + + # Animate planning process + if "anim_iters" in plot_styles: + ref_agent_x = last_plan_x + offset_x[None] + def anim_background(): + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_maps[plan_data['map_name'][batch_i]], + # map_name=env_helper.get_map_name_from_sample_token(), + map_patch=(ref_agent_x[:, 0].min() - plot_margin_m, + ref_agent_x[:, 1].min() - plot_margin_m, + ref_agent_x[:, 0].max() + plot_margin_m, + ref_agent_x[:, 1].max() + plot_margin_m)) + plot_lane(ax, lanes_xy, offset_x, c="pink") + if lane_points_xy is not None: + plot_lane(ax, lane_points_xy, offset_x, c="black") + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='pred', c=pred_color, probs=pred_probs_batch[batch_i]) + plot_traj(ax, gt_pred_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=True) + # In the case of prediction target offset, this will be different. + plot_traj(ax, gt_pred_target_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=False) + plt.suptitle(label, fontsize=10) + return fig, ax + + def animate_plan_iters(fig, ax, plan_iters, *args, **kwargs): + xy = plan_iters['x'][0][:, batch_i].cpu().numpy() + ax_plot, = plot_traj(ax, xy, offset_x, *args, **kwargs) + ax_text = plt.text(0.05, 0.05, "Iter 0", transform=ax.transAxes) + def init(): + return ax_plot, + def animate(i): + i = min(i, len(plan_iters['x'])-1) + xy = plan_iters['x'][i][:, batch_i].cpu().numpy() + xy = apply_offset(xy, offset_x) + ax_plot.set_data(xy[..., 0], xy[..., 1]) + ax_text.set_text(f"Iter {i}; cost={plan_iters['cost'][i][batch_i].cpu().numpy():.3f}") + return ax_plot, + anim = FuncAnimation(fig, animate, init_func=init, + frames=len(plan_iters['x'])+2, interval=60, blit=False) + plt.show() + plt.pause(1.0) + print("anim done") + return anim + + # fig, ax = anim_background() + # anim1 = animate_plan_iters(fig, ax, nopred_plan_iters, label='nopred plan', c=plan_nopred_color) + # output['anim_iters_nopred'].append(anim1) + + # fig, ax = anim_background() + # anim2 = animate_plan_iters(fig, ax, gt_plan_iters, label='gt plan', c=plan_gt_color) + # output['anim_iters_gt'].append(anim2) + + fig, ax = anim_background() + anim2 = animate_plan_iters(fig, ax, plan_iters, label='mpc plan', c=mpc_plan_color) + output['anim_iters_mpc'].append(anim2) + + # anim1.save(save_plot_paths[batch_i] + '-nopred_plan.gif',writer='imagemagick', fps=2) + # anim2.save(save_plot_paths[batch_i] + '-gt_plan.gif',writer='imagemagick', fps=2) + # plt.show() + # plt.pause(1.0) + + # Plot compare plans + if "compare_futures" in plot_styles: + ref_agent_x = last_plan_x + offset_x[None] + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_maps[plan_data['map_name'][batch_i]], + map_patch=(ref_agent_x[:, 0].min() - plot_margin_m, + ref_agent_x[:, 1].min() - plot_margin_m, + ref_agent_x[:, 0].max() + plot_margin_m, + ref_agent_x[:, 1].max() + plot_margin_m)) + + plot_lane(ax, lanes_xy, offset_x, c="pink") + if lane_points_xy is not None: + plot_lane(ax, lane_points_xy, offset_x, c="black") + + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='pred', c=pred_color, probs=pred_probs_batch[batch_i]) + plot_traj(ax, gt_pred_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=True) + # In the case of prediction target offset, this will be different. + plot_traj(ax, gt_pred_target_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=False) + plot_traj(ax, gt_neighbors_xy.transpose((1, 0, 2)), offset_x, label='gt', c=gt_color) + # plot_traj(ax, last_nof_plan_x[:, :2], offset_x, label='nofuture plan', c=plan_nof_color, plot_dot=False) + plot_traj(ax, plan_gt_xy, offset_x, label='ego gt', c=ego_gt_color, plot_dot=True, linewidth=2.0) + plot_traj(ax, last_gt_plan_x[:, :2], offset_x, label='gt plan', c=plan_gt_color, plot_dot=False) + plot_traj(ax, last_nopred_plan_x[:, :2], offset_x, label='nopred plan', c=plan_nopred_color, plot_dot=False) + plot_traj(ax, last_plan_x[:, :2], offset_x, label='mpc plan', c=mpc_plan_color, plot_dot=False) + + # ax.scatter(ref_agent_x[[0], 0], ref_agent_x[[0], 1], label='Ego', c=ego_color) + # ax.plot(ref_agent_x[:, 0], ref_agent_x[:, 1], label='Ego Motion Plan', c=ego_color) + + plt.suptitle(label, fontsize=10) + + legend_unique_labels(plt.gca()) + + output['compare_futures'].append(fig) + # plt.savefig(save_plot_paths[batch_i] + '-plan.png') + # plt.show() + + if planner in ['fan', 'fan_mpc']: + # Plot fan planner + traj_xy = plan_iters_i['traj_xu'][..., :2] # N, T+1, 6 + traj_costs = plan_iters_i['traj_cost'] + plan_i = traj_costs.argmin() + fan_plan_xy = traj_xy[plan_i] + label_goaldist_xy = traj_xy[plan_iters_i['label_goaldist']] + label_hcost_xy = traj_xy[plan_iters_i['label_hcost']] + plan_gt_xy = plan_iters_i['x_gt'][:, :2] + plan_gt_xy = np.concatenate([fan_plan_xy[:1, :2], plan_gt_xy], axis=0) # append t0 + + # Filter + # if not plan_iters_i['converged']: + # continue + plotted_inds.append(batch_i) + + # TODO mse and hcost are wrong for fan_mpc + label = ( + f"{titles[batch_i]} \n" + + f"#{batch_i} n={len(gt_neighbors_xy)} conv={str(plan_iters_i['converged'])} " + + f"pred_ade={pred_ade_biased[batch_i]:.3f} {pred_ade_unbiased[batch_i]:.3f} plan_mse={plan_metrics['mse'][batch_i]:.3f} plan_cost={traj_costs[plan_i]:.3f} plan_hcost={plan_metrics['hcost'][batch_i]:.3f} \n" + + f"#candid={traj_xy.shape[0]} label_goaldist={plan_iters_i['label_goaldist']} label_hcost={plan_iters_i['label_hcost']} lowest_cost={plan_i} plan_loss={plan_iters_i['plan_loss']:.3f} \n" + + f"cost " + " ".join([f"{c:.2f}" for c in cost_components_batch[batch_i]]) + "\n" + + f"hcost " + " ".join([f"{c:.2f}" for c in hcost_components_batch[batch_i]]) + ) + print (label) + + if "fan_with_pred" in plot_styles: + ref_agent_x = plan_gt_xy + offset_x[None, :2] + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_maps[plan_data['map_name'][batch_i]], + map_patch=np.array((ref_agent_x[:, 0].min() - plot_margin_m, + ref_agent_x[:, 1].min() - plot_margin_m, + ref_agent_x[:, 0].max() + plot_margin_m, + ref_agent_x[:, 1].max() + plot_margin_m))) + + plot_lane(ax, lanes_xy, offset_x, c="pink") + if lane_points_xy is not None: + plot_lane(ax, lane_points_xy, offset_x, c="black") + + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='pred', c=pred_color, probs=pred_probs_batch[batch_i]) + plot_traj(ax, gt_pred_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=True) + # In the case of prediction target offset, this will be different. + plot_traj(ax, gt_pred_target_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=False) + + plot_traj(ax, gt_neighbors_xy.transpose((1, 0, 2)), offset_x, label='gt', c=gt_color) + + plot_traj(ax, plan_gt_xy, offset_x, label='ego gt', c=ego_gt_color, plot_dot=True) + + # Plot candidate targets only + plot_lane(ax, traj_xy.transpose((1, 0, 2)), offset_x, label='trajectory fan', c=plan_tree_color) + # Plot candidate trajectories + # plot_traj(ax, traj_xy.transpose((1, 0, 2)), offset_x, label='tree', c=plan_tree_color) + # plot_traj(ax, gt_plan_x[:, :2], offset_x, label='gt plan', c=plan_color, plot_dot=False) + plot_traj(ax, label_goaldist_xy, offset_x, label='label goaldist', c=ego_gt_color, plot_dot=False) + plot_traj(ax, label_hcost_xy, offset_x, label='label hcost', c=ego_gthcost_color, plot_dot=False) + plot_traj(ax, fan_plan_xy, offset_x, label='ego plan', c=plan_color, plot_dot=False) + + + # ax.scatter(ref_agent_x[[0], 0], ref_agent_x[[0], 1], label='Ego', c=ego_color) + # ax.plot(ref_agent_x[:, 0], ref_agent_x[:, 1], label='Ego Motion Plan', c=ego_color) + + plt.suptitle(label, fontsize=10) + + legend_unique_labels(plt.gca(), loc='lower left') + + output['fan_with_pred'].append(fig) + # plt.show() + + if "compare_fan_vs_mpc" in plot_styles: + # Plot fan_mpc planner + assert planner == 'fan_mpc' + + traj_xy = plan_iters_i['traj_xu'][..., :2] # N, T+1, 6 + traj_costs = plan_iters_i['traj_cost'] + plan_i = traj_costs.argmin() + fan_plan_xy = traj_xy[plan_i] + label_goaldist_xy = traj_xy[plan_iters_i['label_goaldist']] + label_hcost_xy = traj_xy[plan_iters_i['label_hcost']] + plan_gt_xy = plan_iters_i['x_gt'][:, :2] + plan_gt_xy = np.concatenate([fan_plan_xy[:1, :2], plan_gt_xy], axis=0) # append t0 + fan_mse = np.square(subsample_future(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() + mpc_plan_cost = plan_iters['cost'][-1][batch_i].cpu().numpy() / plan_gt_xy.shape[0] # from sum over time to mean over time + + # Filter + # if not plan_iters_i['converged']: + # continue + plotted_inds.append(batch_i) + + label = ( + f"{titles[batch_i]} \n" + + f"#{batch_i} n={len(gt_neighbors_xy)} conv={str(plan_iters_i['fan_converged'])} {str(plan_iters_i['mpc_converged'])} " + + f"#candid={traj_xy.shape[0]} label_goaldist={plan_iters_i['label_goaldist']} label_hcost={plan_iters_i['label_hcost']} lowest_cost={plan_i} plan_loss={plan_iters_i['plan_loss']:.3f} \n" + + f"fan_mse={fan_mse:.3f} fan_mcost={traj_costs[plan_i]/plan_gt_xy.shape[0]:.3f} fan_hcost={plan_iters_i['traj_hcost'][plan_i]/plan_gt_xy.shape[0]:.3f} \n" + + f"mpc_mse={plan_metrics['mse'][batch_i]:.3f} mpc_mcost={mpc_plan_cost:.3f} mpc_hcost={plan_metrics['hcost'][batch_i]:.3f} \n" + + f"fan cost " + " ".join([f"{c:.2f}" for c in plan_iters['fan_cost_components'][:, batch_i].mean(0).cpu().numpy()]) + "\n" + + f"mpc cost " + " ".join([f"{c:.2f}" for c in cost_components_batch[batch_i]]) + ) + print (label) + + dist = plot_data['y_dists'] + # TODO requires batch=1 + # original_batch_i = torch.arange(dist.mus.shape[1], device=dist.mus.device)[plan_batch_filter][batch_i] + dist.mus = dist.mus.detach().clone() + dist.mus[..., 0] += offset_x[0] + dist.mus[..., 1] += offset_x[1] + ml_k = np.argmax(pred_probs_batch[batch_i]) + pred_ml = pred_mus[:, ml_k] + + for rep_i in range(4): + ref_agent_x = fan_plan_xy + offset_x[None, :2] + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_maps[plan_data['map_name'][batch_i]], + map_patch=np.array((ref_agent_x[:, 0].min() - plot_margin_m/2, + ref_agent_x[:, 1].min() - plot_margin_m/2, + ref_agent_x[:, 0].max() + plot_margin_m/2, + ref_agent_x[:, 1].max() + plot_margin_m/2)), + # figsize=(8,8), + figsize=(24,24), + ) + + dist.log_pis = plot_data['y_dists'].log_pis.detach().clone() + if rep_i == 0: + probs = pred_probs_batch[batch_i] + alphas = probs / probs.max() # * 0.6 + 0.4 + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='predictions', c='orange', probs=pred_probs_batch[batch_i], alphas=alphas, linewidth=2.0) + + # visualize_distribution2(plt.gca(), dist, pi_threshold=0.05, color=pred_color, pi_alpha=0.1, topn=1) + elif rep_i == 1: + probs = pred_probs_batch[batch_i] + alphas = probs / probs.max() * 0.8 + 0.2 + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='predictions', c='orange', probs=pred_probs_batch[batch_i], alphas=alphas, linewidth=2.0) + # visualize_distribution2(plt.gca(), dist, pi_threshold=0.05, color=pred_color, pi_alpha=0.1, topn=3) + elif rep_i == 2: + probs = pred_probs_batch[batch_i] + alphas = probs / probs.max() * 0.9 + 0.1 + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='predictions', c='orange', probs=pred_probs_batch[batch_i], alphas=alphas, linewidth=2.0) + # visualize_distribution2(plt.gca(), dist, pi_threshold=0.05, color=pred_color, pi_alpha=0.1) + else: + probs = pred_probs_batch[batch_i] + alphas = probs / probs.max() * 0.95 + 0.05 + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='predictions', c='orange', probs=pred_probs_batch[batch_i], alphas=alphas, linewidth=2.0) + # visualize_distribution2(plt.gca(), dist, pi_threshold=0.05, color=pred_color, pi_alpha=0.1) + + # visualize_distribution2(plt.gca(), dist, pi_threshold=0.05, color=pred_color, pi_alpha=0.1) + + # plot_preds(ax, pred_mus[:, :, :2], offset_x, label='predictions', c='orange', probs=pred_probs_batch[batch_i], linewidth=2.0) + plot_traj(ax, gt_pred_xy[:, :2], offset_x, label='gt_future', c=gt_color, linewidth=2.5, plot_dot=True) + plot_traj(ax, gt_neighbors_xy.transpose((1, 0, 2)), offset_x, label='gt_future', linewidth=2.5, c=gt_color) + plot_traj(ax, plan_gt_xy, offset_x, label='gt_future', c=gt_color, linewidth=2.5, plot_dot=True) + + # plot_traj(ax, fan_plan_xy, offset_x, label='fan plan', c=plan_color, linewidth=2.0, plot_dot=False) + plot_traj(ax, last_plan_x[:, :2], offset_x, label='ego_plan', c='red', linewidth=2.5, plot_dot=False) + # TODO use plot_lane to add markers (at subsampled timesteps) to indicate velocity + + + + plot_traj(ax, pred_ml[:, :2], offset_x, label='dist_prediction', linewidth=2.5, c=pred_color, plot_dot=False) # only to make it appear in legend + plot_traj(ax, pred_ml[:, :2], offset_x, label='ml_prediction', linewidth=2.5, c='yellow', plot_dot=False) + + + # Hide grid lines + ax.grid(False) + # Hide axes ticks + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect('equal') + + if rep_i == 0: + handles, labels = plt.gca().get_legend_handles_labels() + labels, legend_ids = np.unique(labels, return_index=True) + # handles = [handles[i] for i in legend_ids] + oredered_handles = [] + oredered_labels = ['gt_future', 'ego_plan', 'ml_prediction', 'dist_prediction', 'drivable_area', 'lane', + 'lane_divider', 'walkway', 'ped_crossing', 'stop_line', 'road_divider', 'road_segment', 'carpark_area', + ] + for lb in oredered_labels: + finds = np.where(labels==lb)[0] + if len(finds) < 1: + continue + i = finds[0] + oredered_handles.append(handles[legend_ids[i]]) + plt.legend(oredered_handles, oredered_labels, loc='lower left', framealpha=1.) + output['compare_fan_vs_mpc_label'].append(fig) + elif rep_i == 1: + output['compare_fan_vs_mpc'].append(fig) + else: + output['compare_fan_vs_mpc'+str(rep_i)].append(fig) + + + if "OLD_compare_fan_vs_mpc" in plot_styles: + # Plot fan_mpc planner + assert planner == 'fan_mpc' + + traj_xy = plan_iters_i['traj_xu'][..., :2] # N, T+1, 6 + traj_costs = plan_iters_i['traj_cost'] + plan_i = traj_costs.argmin() + fan_plan_xy = traj_xy[plan_i] + label_goaldist_xy = traj_xy[plan_iters_i['label_goaldist']] + label_hcost_xy = traj_xy[plan_iters_i['label_hcost']] + plan_gt_xy = plan_iters_i['x_gt'][:, :2] + plan_gt_xy = np.concatenate([fan_plan_xy[:1, :2], plan_gt_xy], axis=0) # append t0 + fan_mse = np.square(subsample_future(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() + mpc_plan_cost = plan_iters['cost'][-1][batch_i].cpu().numpy() / plan_gt_xy.shape[0] # from sum over time to mean over time + + # Filter + # if not plan_iters_i['converged']: + # continue + plotted_inds.append(batch_i) + + label = ( + f"{titles[batch_i]} \n" + + f"#{batch_i} n={len(gt_neighbors_xy)} conv={str(plan_iters_i['fan_converged'])} {str(plan_iters_i['mpc_converged'])} " + + f"#candid={traj_xy.shape[0]} label_goaldist={plan_iters_i['label_goaldist']} label_hcost={plan_iters_i['label_hcost']} lowest_cost={plan_i} plan_loss={plan_iters_i['plan_loss']:.3f} \n" + + f"fan_mse={fan_mse:.3f} fan_mcost={traj_costs[plan_i]/plan_gt_xy.shape[0]:.3f} fan_hcost={plan_iters_i['traj_hcost'][plan_i]/plan_gt_xy.shape[0]:.3f} \n" + + f"mpc_mse={plan_metrics['mse'][batch_i]:.3f} mpc_mcost={mpc_plan_cost:.3f} mpc_hcost={plan_metrics['hcost'][batch_i]:.3f} \n" + + f"fan cost " + " ".join([f"{c:.2f}" for c in plan_iters['fan_cost_components'][:, batch_i].mean(0).cpu().numpy()]) + "\n" + + f"mpc cost " + " ".join([f"{c:.2f}" for c in cost_components_batch[batch_i]]) + ) + print (label) + + ref_agent_x = fan_plan_xy + offset_x[None, :2] + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_maps[plan_data['map_name'][batch_i]], + map_patch=np.array((ref_agent_x[:, 0].min() - plot_margin_m/2, + ref_agent_x[:, 1].min() - plot_margin_m/2, + ref_agent_x[:, 0].max() + plot_margin_m/2, + ref_agent_x[:, 1].max() + plot_margin_m/2))) + + plot_lane(ax, lanes_xy, offset_x, c="pink") + if lane_points_xy is not None: + plot_lane(ax, lane_points_xy, offset_x, c="black") + + plot_preds(ax, pred_mus[:, :, :2], offset_x, label='pred', c=pred_color, probs=pred_probs_batch[batch_i]) + plot_traj(ax, gt_pred_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=True) + # In the case of prediction target offset, this will be different. + plot_traj(ax, gt_pred_target_xy[:, :2], offset_x, label='pred gt', c=pred_gt_color, plot_dot=False) + + plot_traj(ax, gt_neighbors_xy.transpose((1, 0, 2)), offset_x, label='gt', c=gt_color) + + plot_traj(ax, plan_gt_xy, offset_x, label='ego gt', c=ego_gt_color, linewidth=2.0, plot_dot=True) + + plot_traj(ax, fan_plan_xy, offset_x, label='fan plan', c=plan_color, linewidth=2.0, plot_dot=False) + plot_traj(ax, last_plan_x[:, :2], offset_x, label='mpc plan', c=mpc_plan_color, linewidth=1.5, plot_dot=False) + # TODO use plot_lane to add markers (at subsampled timesteps) to indicate velocity + + plt.suptitle(label, fontsize=10) + + legend_unique_labels(plt.gca(), loc='lower left') + + output['compare_fan_vs_mpc'].append(fig) + + # Remove duplicates + plotted_inds = list(OrderedDict.fromkeys(plotted_inds)) + return output, plotted_inds + + +def visualize_closed_loop(sim_hist, scenario_metrics, scene, nusc_maps, animate=True): + """ + all_sim_hist[t][node_type] --> [N][state_dim] + """ + plot_margin_m = 30 + ego_log_color = 'royalblue' + ego_sim_color = 'blue' + gt_color = 'black' + pred_color = 'orange' + pred_gt_color = 'yellow' + plan_color = 'red' + + output = defaultdict(list) + + def apply_offset(xy, offset_x): + if offset_x is not None and xy is not None: + while offset_x.ndim < xy.ndim: + offset_x = offset_x[None] + xy = xy[..., :2] + offset_x[..., :2] + return xy + + def plot_traj(ax, xy, offset_x, label=None, c=None, plot_dot=True, dot_last=False, linewidth=1.5, **kwargs): + xy = apply_offset(xy, offset_x) + if plot_dot: + dot_ind = -1 if dot_last else 0 + ax_scatter = ax.scatter(xy[[dot_ind], ..., 0], xy[[dot_ind], ..., 1], c=c, s=80.0) + else: + ax_scatter = None + ax_plots = ax.plot(xy[..., 0], xy[..., 1], label=label, c=c, linewidth=linewidth, **kwargs) + return ax_plots, ax_scatter + + def plot_preds(ax, xy, offset_x, probs=None, label=None, c=None, linewidth=0.75, alphas=None, **kwargs): + xy = apply_offset(xy, offset_x) + if probs is None: + ax_plots = ax.plot(xy[..., 0], xy[..., 1], label=label, c=c, linewidth=linewidth, **kwargs) + else: + # Normalize probs to 0.2...1 + if alphas is None: + alphas = probs / probs.max() * 0.6 + 0.4 + # Unfortunately there is no support for multiple alpha values so we need to loop + ax_plots = [] + for i in range(xy.shape[1]): + ax_plots.extend(ax.plot(xy[:, i, 0], xy[:, i, 1], label=label, c=c, linewidth=linewidth, alpha=alphas[i], **kwargs)) + return ax_plots + + def plot_lane(ax, lanes, offset_x, label="lane", c='black', marker='x', s=4, **kwargs): + lanes = apply_offset(lanes, offset_x) + return ax.scatter(lanes[..., 0], lanes[..., 1], label=label, c=c, marker=marker, s=s, linewidths=1, **kwargs) + + offset_x = np.array([scene.x_min, scene.y_min, 0., 0.]) + nusc_map = nusc_maps[scene.map_name] + + # label = ( + # f"{titles[batch_i]} \n" + + # f"#{batch_i} n={len(gt_neighbors_xy)} conv={str(plan_iters_i['converged'])} " + + # f"pred_ade={pred_ade_biased[batch_i]:.3f} {pred_ade_unbiased[batch_i]:.3f} plan_mse={plan_metrics['mse'][batch_i]:.3f} plan_mcost={plan_cost_i:.3f} plan_hcost={plan_metrics['hcost'][batch_i]:.3f} \n" + + # # f"nof_plan_hcost={nof_plan_metrics['hcost'][batch_i]:.3f} nopred_plan_hcost={nopred_plan_metrics['hcost'][batch_i]:.3f} gtplan_hcost={gt_plan_metrics['hcost'][batch_i]:.3f} \n" + + # f"nopred_plan_hcost={nopred_plan_metrics['hcost'][batch_i]:.3f} gtplan_hcost={gt_plan_metrics['hcost'][batch_i]:.3f} \n" + + # f"nopred_regret_now={nopred_plan_metrics['hcost'][batch_i]-gt_plan_metrics['hcost'][batch_i]:.4f} " + + # f"nopred_regret_cache={plan_data['nopred_plan_hcost'][batch_i]-plan_data['gt_plan_hcost'][batch_i]:.4f}") + # print (label) + label = f"Closed-loop" + + + def subsample_history(sim_hist, i, offset_x): + plan_xu = np.concatenate([np.concatenate(sim_hist['plan_x'][:i+1], axis=0), np.concatenate(sim_hist['plan_u'][:i+1], axis=0)], axis=-1) # T+1, b + gt_ego = plan_xu[..., :2] + log_ego = np.stack(sim_hist['logged_x'][:i+1], axis=0)[..., :2] + + plan_all_gt_neighbors = np.concatenate(sim_hist['plan_all_gt_neighbors_batch'][i:i+1], axis=1) # N, 1 -- only last step, otherwise need to deal with nans and association + goal_batch = sim_hist['logged_x'][-1][..., :2] + lanes = np.concatenate(sim_hist['lanes'][i], axis=0) # T+1, + # TODO support plotting predictions. For now return empty prediction structure. + empty_mus_batch = np.zeros((0, plan_all_gt_neighbors.shape[1], 1, 1, 2), dtype=np.float32) + empty_logp_batch = np.zeros((0, 1, 1), dtype=np.float32) + lane_points = None + + # Separate predicted agent and gt neighbors + neighbor_invalid = np.isnan(plan_all_gt_neighbors).any(axis=2).any(axis=1) + plan_all_gt_neighbors = plan_all_gt_neighbors[neighbor_invalid == False] + gt_neighbors = plan_all_gt_neighbors[:-1] + gt_pred = plan_all_gt_neighbors[-1:] + + return ( + None, apply_offset(gt_ego, offset_x), apply_offset(log_ego, offset_x), + apply_offset(gt_neighbors, offset_x), apply_offset(gt_pred, offset_x), + apply_offset(empty_mus_batch, offset_x), apply_offset(empty_logp_batch, offset_x), apply_offset(goal_batch, offset_x), + apply_offset(lanes, offset_x), apply_offset(lane_points, offset_x)) + + # Concatenate all planned trajectory into a dummy collection of reference points + ref_agent_x = [] + for plan_x in sim_hist['plan_x']: + ref_agent_x.append(plan_x[:, :4]) + ref_agent_x = np.concatenate(ref_agent_x, axis=0) + offset_x[None] + + def anim_background(): + fig, ax = plot_map(dataroot=None, map_name=None, nusc_map=nusc_map, + # map_name=env_helper.get_map_name_from_sample_token(), + map_patch=(ref_agent_x[:, 0].min() - plot_margin_m, + ref_agent_x[:, 1].min() - plot_margin_m, + ref_agent_x[:, 0].max() + plot_margin_m, + ref_agent_x[:, 1].max() + plot_margin_m), figsize=(6, 6)) + plt.suptitle(label, fontsize=10) + ax.set_aspect('equal') + return fig, ax + + def plot_sim_state(fig, ax, plan_xu, gt_ego, log_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points): + ax_lane = plot_lane(ax, lanes, None, c="pink") + ax_gt_agent_traj, ax_gt_agent_dots = plot_traj(ax, gt_neighbors.transpose((1, 0, 2)), None, dot_last=True, label='gt', c=gt_color) + ax_gt_pred_traj, ax_gt_pred_dots = plot_traj(ax, gt_pred.transpose((1, 0, 2)), None, dot_last=True, label='gt pred', c=pred_gt_color) + # plot_preds(ax, pred_mus[:, :, :2], offset_x, label='pred', c=pred_color, probs=pred_probs_batch[batch_i]) + ax_ego_traj, ax_ego_dots = plot_traj(ax, gt_ego, None, dot_last=True, label='ego sim', c=ego_sim_color) + ax_ego_log_traj, ax_ego_log_dots = plot_traj(ax, log_ego, None, plot_dot=False, dot_last=True, label='ego sim', c=ego_log_color) + ax_text = plt.text(0.05, 0.05, "Step 0", transform=ax.transAxes) + return ax_lane, ax_gt_agent_traj, ax_gt_agent_dots, ax_gt_pred_traj, ax_gt_pred_dots, ax_ego_traj, ax_ego_dots, ax_ego_log_traj, ax_ego_log_dots, ax_text + + def animate_plan_iters(fig, ax, sim_hist): + + plan_xu, gt_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points = subsample_history(sim_hist, 0, offset_x) + ax_lane, ax_gt_agent_traj, ax_gt_agent_dots, ax_gt_pred_traj, ax_gt_pred_dots, ax_ego_traj, ax_ego_dots, ax_ego_log_traj, ax_ego_log_dots, ax_text = plot_sim_state(fig, ax, plan_xu, gt_ego, log_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points) + def init(): + return ax_lane, ax_gt_agent_traj, ax_gt_agent_dots, ax_gt_pred_traj, ax_gt_pred_dots, ax_ego_traj, ax_ego_dots, ax_ego_log_traj, ax_ego_log_dots, ax_text + def animate(i): + i = min(i, len(sim_hist['plan_x'])-1) + plan_xu, gt_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points = subsample_history(sim_hist, i, offset_x) + + ax_lane.set_offsets(lanes[-1]) + ax_gt_agent_traj, ax_gt_agent_dots = plot_traj(ax, gt_neighbors.transpose((1, 0, 2)), None, label='gt', c=gt_color) + ax_gt_pred_traj, ax_gt_pred_dots = plot_traj(ax, gt_pred.transpose((1, 0, 2)), None, label='gt pred', c=gt_color) + ax_ego_traj, ax_ego_dots = plot_traj(ax, gt_ego, None, label='ego', c=gt_color) + # ax_gt_agent_traj.set_data(gt_neighbors.transpose((1, 0, 2))[..., 0], gt_neighbors.transpose(1, 0, 2)[..., 1]) + # ax_gt_agent_dots.set_offsets(gt_neighbors.transpose((1, 0, 2))[-1]) + # ax_gt_pred_traj.set_data(gt_pred.transpose((1, 0, 2))[..., 0], gt_neighbors.transpose(1, 0, 2)[..., 1]) + # ax_gt_pred_dots.set_offsets(gt_pred.transpose((1, 0, 2))[-1]) + # ax_ego_traj.set_data(gt_ego[..., 0], gt_ego[..., 1]) + # ax_ego_dots.set_offsets(gt_ego[-1]) + + ax_text.set_text(f"Step {i}") + return ax_lane, ax_gt_agent_traj, ax_gt_agent_dots, ax_gt_pred_traj, ax_gt_pred_dots, ax_ego_traj, ax_ego_dots, ax_ego_log_traj, ax_ego_log_dots, ax_text + + anim = FuncAnimation(fig, animate, init_func=init, + frames=len(sim_hist['plan_x'])+2, interval=60, blit=False) + plt.show() + plt.pause(1.0) + print("anim done") + return anim + + + + for i in range(len(sim_hist['plan_x'])): + fig, ax = anim_background() + plan_xu, gt_ego, log_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points = subsample_history(sim_hist, i, offset_x) + plot_sim_state(fig, ax, plan_xu, gt_ego, log_ego, gt_neighbors, gt_pred, mus, logp, goal, lanes, lane_points) + plt.show() + + # Animate planning process + if animate: + + # fig, ax = anim_background() + # anim1 = animate_plan_iters(fig, ax, nopred_plan_iters, label='nopred plan', c=plan_nopred_color) + # output['anim_iters_nopred'].append(anim1) + + # fig, ax = anim_background() + # anim2 = animate_plan_iters(fig, ax, gt_plan_iters, label='gt plan', c=plan_gt_color) + # output['anim_iters_gt'].append(anim2) + + fig, ax = anim_background() + anim2 = animate_plan_iters(fig, ax, sim_hist) + output['anim_closed_loop'].append(anim2) + anim2.save('./cache/closed-loop' + '.gif',writer='imagemagick', fps=2) + + # anim2.save(save_plot_paths[batch_i] + '-gt_plan.gif',writer='imagemagick', fps=2) + # plt.show() + # plt.pause(1.0) + + return output diff --git a/patches/trajdata_vectorize.patch b/patches/trajdata_vectorize.patch new file mode 100644 index 0000000..529585b --- /dev/null +++ b/patches/trajdata_vectorize.patch @@ -0,0 +1,7097 @@ +diff --git a/.gitignore b/.gitignore +index 1e55dd9..1ad9436 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -157,3 +157,4 @@ cython_debug/ + # option (not recommended) you can uncomment the following to ignore the entire idea folder. + #.idea/ + ++tests/Drivesim_scene_generation.html +diff --git a/CITATION.cff b/CITATION.cff +index e793eb6..dbf28a0 100644 +--- a/CITATION.cff ++++ b/CITATION.cff +@@ -5,23 +5,7 @@ authors: + given-names: "Boris" + orcid: "https://orcid.org/0000-0002-8698-202X" + title: "trajdata: A unified interface to many trajectory forecasting datasets" +-version: 1.3.3 ++version: 1.0.3 + doi: 10.5281/zenodo.6671548 +-date-released: 2023-08-22 +-url: "https://github.com/nvr-avg/trajdata" +-preferred-citation: +- type: conference-paper +- authors: +- - family-names: "Ivanovic" +- given-names: "Boris" +- orcid: "https://orcid.org/0000-0002-8698-202X" +- - family-names: "Song" +- given-names: "Guanyu" +- - family-names: "Gilitschenski" +- given-names: "Igor" +- - family-names: "Pavone" +- given-names: "Marco" +- journal: "Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks" +- month: 12 +- title: "trajdata: A Unified Interface to Multiple Human Trajectory Datasets" +- year: 2023 +\ No newline at end of file ++date-released: 2022-06-20 ++url: "https://github.com/nvr-avg/trajdata" +\ No newline at end of file +diff --git a/README.md b/README.md +index 774d6d6..a715ebc 100644 +--- a/README.md ++++ b/README.md +@@ -1,4 +1,4 @@ +-# trajdata: A Unified Interface to Multiple Human Trajectory Datasets ++# Unified Trajectory Data Loader + + [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) +@@ -6,10 +6,6 @@ + [![DOI](https://zenodo.org/badge/488789438.svg)](https://zenodo.org/badge/latestdoi/488789438) + [![PyPI version](https://badge.fury.io/py/trajdata.svg)](https://badge.fury.io/py/trajdata) + +-### Announcements +- +-**Sept 2023**: [Our paper about trajdata](https://arxiv.org/abs/2307.13924) has been accepted to the NeurIPS 2023 Datasets and Benchmarks Track! +- + ## Installation + + The easiest way to install trajdata is through PyPI with +@@ -211,20 +207,6 @@ for t in range(1, sim_scene.scene.length_timesteps): + + `examples/sim_example.py` contains a more comprehensive example which initializes a simulation from a scene in the nuScenes mini dataset, steps through it by replaying agents' GT motions, and computes metrics based on scene statistics (e.g., displacement error from the original GT data, velocity/acceleration/jerk histograms). + +-## Citation +- +-If you use this software, please cite it as follows: +-``` +-@Inproceedings{ivanovic2023trajdata, +- author = {Ivanovic, Boris and Song, Guanyu and Gilitschenski, Igor and Pavone, Marco}, +- title = {{trajdata}: A Unified Interface to Multiple Human Trajectory Datasets}, +- booktitle = {{Proceedings of the Neural Information Processing Systems (NeurIPS) Track on Datasets and Benchmarks}}, +- month = dec, +- year = {2023}, +- address = {New Orleans, USA}, +- url = {https://arxiv.org/abs/2307.13924} +-} +-``` +- + ## TODO + - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. ++- Add more examples to the README. +diff --git a/copy_to_public.sh b/copy_to_public.sh +new file mode 100755 +index 0000000..dba3133 +--- /dev/null ++++ b/copy_to_public.sh +@@ -0,0 +1 @@ ++rsync -ravh --progress --exclude ".gitlab*" --exclude "public/" --exclude "opendrive" --exclude "copy_to_public.sh" --exclude "*.pyc" --exclude "ngc/" --exclude "*.egg-info/" --exclude "__pycache__/" ./* public/ +\ No newline at end of file +diff --git a/examples/map_api_example.py b/examples/map_api_example.py +index 759b2e0..75077ce 100644 +--- a/examples/map_api_example.py ++++ b/examples/map_api_example.py +@@ -4,6 +4,7 @@ from typing import Dict, List, Optional + + import matplotlib.pyplot as plt + import numpy as np ++import os + + from trajdata import MapAPI, VectorMap + from trajdata.caching.df_cache import DataFrameCache +@@ -23,7 +24,7 @@ def load_random_scene(cache_path: Path, env_name: str, scene_dt: float) -> Scene + + + def main(): +- cache_path = Path("~/.unified_data_cache").expanduser() ++ cache_path = Path(os.environ["TRAJDATA_CACHE_DIR"]).expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. +diff --git a/ngc/Dockerfile b/ngc/Dockerfile +new file mode 100644 +index 0000000..c5841bc +--- /dev/null ++++ b/ngc/Dockerfile +@@ -0,0 +1,27 @@ ++FROM nvcr.io/nvidia/pytorch:21.03-py3 ++ARG PYTHON_VERSION=3.8 ++ ++RUN apt-get update ++RUN apt-get install htop -y ++RUN apt-get install screen -y ++RUN apt-get install psmisc -y ++ ++RUN pip install --upgrade pip ++ ++RUN pip install \ ++numpy==1.19 \ ++tqdm==4.62 \ ++matplotlib==3.5 \ ++dill==0.3.4 \ ++pandas==1.4.1 \ ++pyarrow==7.0.0 \ ++nuscenes-devkit==1.1.9 \ ++l5kit==1.5.0 \ ++black==22.1.0 \ ++isort==5.10.1 \ ++pytest==7.1.1 \ ++pytest-xdist==2.5.0 \ ++zarr==2.11.0 \ ++kornia==0.6.4 ++ ++WORKDIR /workspace/trajdata +diff --git a/ngc/copy_to_ngc.sh b/ngc/copy_to_ngc.sh +new file mode 100755 +index 0000000..c42315f +--- /dev/null ++++ b/ngc/copy_to_ngc.sh +@@ -0,0 +1,4 @@ ++#!/bin/bash ++ ++cd ../.. ++rsync -ravh --progress --exclude="*.pyc" --exclude=".git" --exclude="__pycache__" --exclude="*.egg-info" --exclude=".pytest_cache" trajdata/ ngc-trajdata/ +diff --git a/ngc/jupyter_lab.json b/ngc/jupyter_lab.json +new file mode 100644 +index 0000000..7d43c0c +--- /dev/null ++++ b/ngc/jupyter_lab.json +@@ -0,0 +1,49 @@ ++{ ++ "userLabels":[ ++ ++ ], ++ "aceId":257, ++ "aceInstance":"dgx1v.16g.8.norm", ++ "dockerImageName":"nvidian/nvr-av/trajdata:1.0", ++ "aceName":"nv-us-west-2", ++ "systemLabels":[ ++ ++ ], ++ "datasetMounts":[ ++ { ++ "containerMountPoint":"/workspace/lyft", ++ "id":90893 ++ }, ++ { ++ "containerMountPoint":"/workspace/nuScenes", ++ "id":78251 ++ } ++ ], ++ "workspaceMounts":[ ++ { ++ "containerMountPoint":"/workspace/trajdata", ++ "id":"aWUsWZ_5SA6caFr30uWmmw", ++ "mountMode":"RW" ++ }, ++ { ++ "containerMountPoint":"/workspace/trajdata_cache", ++ "id":"BgHSNdSaR0ywsc-UY-xm3w", ++ "mountMode":"RW" ++ } ++ ], ++ "replicaCount":1, ++ "publishedContainerPorts":[ ++ 8888 ++ ], ++ "reservedLabels":[ ++ ++ ], ++ "name":"ml-model.notamodel-trajdata-processing", ++ "command":"pip install -e .; jupyter lab --ip=0.0.0.0 --allow-root --no-browser --NotebookApp.token='' --notebook-dir=/ --NotebookApp.allow_origin='*' & date; nvidia-smi; sleep 1d", ++ "runPolicy":{ ++ "minTimesliceSeconds":1, ++ "totalRuntimeSeconds":86400, ++ "preemptClass":"RUNONCE" ++ }, ++ "resultContainerMountPoint":"/results" ++} +diff --git a/ngc/ngc_test.py b/ngc/ngc_test.py +new file mode 100644 +index 0000000..c63312e +--- /dev/null ++++ b/ngc/ngc_test.py +@@ -0,0 +1,57 @@ ++import os ++from collections import defaultdict ++ ++from torch.utils.data import DataLoader ++from tqdm import tqdm ++ ++from trajdata import AgentBatch, AgentType, UnifiedDataset ++from trajdata.augmentation import NoiseHistories ++from trajdata.visualization.vis import plot_agent_batch ++ ++ ++# @profile ++def main(): ++ noise_hists = NoiseHistories() ++ ++ dataset = UnifiedDataset( ++ desired_data=["nusc"], ++ centric="agent", ++ # desired_dt=0.1, ++ history_sec=(0.1, 1.5), ++ future_sec=(0.1, 5.0), ++ only_types=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ++ agent_interaction_distances=defaultdict(lambda: 40.0), ++ incl_robot_future=True, ++ incl_raster_map=True, ++ raster_map_params={ ++ "px_per_m": 2, ++ "map_size_px": 224, ++ "offset_frac_xy": (-0.5, 0.0), ++ }, ++ # augmentations=[noise_hists], ++ data_dirs={ ++ "nusc": "/workspace/datasets/nuScenes", ++ }, ++ cache_location="/workspace/unified_data_cache", ++ num_workers=os.cpu_count(), ++ # verbose=True, ++ ) ++ ++ print(f"# Data Samples: {len(dataset):,}") ++ ++ dataloader = DataLoader( ++ dataset, ++ batch_size=64, ++ shuffle=True, ++ collate_fn=dataset.get_collate_fn(), ++ num_workers=os.cpu_count(), ++ ) ++ ++ batch: AgentBatch ++ for batch in tqdm(dataloader): ++ pass ++ # plot_agent_batch(batch, batch_idx=0) ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/ngc/preprocess_data_ngc.py b/ngc/preprocess_data_ngc.py +new file mode 100644 +index 0000000..920f439 +--- /dev/null ++++ b/ngc/preprocess_data_ngc.py +@@ -0,0 +1,33 @@ ++from trajdata import UnifiedDataset ++ ++ ++def main(): ++ dataset = UnifiedDataset( ++ desired_data=[ ++ "nusc_trainval", ++ "nusc_mini", ++ "lyft_sample", ++ "lyft_train", ++ # "lyft_train_full", ++ "lyft_val", ++ ], ++ data_dirs={ ++ "nusc_trainval": "/workspace/nuScenes", ++ "nusc_mini": "/workspace/nuScenes", ++ "lyft_sample": "/workspace/lyft/lyft_prediction/scenes/sample.zarr", ++ "lyft_train": "/workspace/lyft/lyft_prediction/scenes/train.zarr", ++ # "lyft_train_full": "/workspace/lyft/lyft_prediction/scenes/train_full.zarr", ++ "lyft_val": "/workspace/lyft/lyft_prediction/scenes/validate.zarr", ++ }, ++ cache_location="/workspace/trajdata_cache", ++ rebuild_cache=True, ++ rebuild_maps=True, ++ num_workers=64, ++ verbose=True, ++ ) ++ ++ print(f"Total Data Samples: {len(dataset):,}") ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/requirements.txt b/requirements.txt +new file mode 100644 +index 0000000..97d46f9 +--- /dev/null ++++ b/requirements.txt +@@ -0,0 +1,26 @@ ++numpy ++tqdm ++matplotlib==3.3.4 ++dill ++pandas ++seaborn ++pyarrow ++torch ++zarr ++kornia ++intervaltree ++ ++# nuScenes devkit ++nuscenes-devkit==1.1.9 ++ ++# Lyft Level 5 devkit ++protobuf==3.19.4 ++l5kit==1.5.0 ++ ++# Development ++black ++isort ++pytest ++pytest-xdist ++twine ++build +diff --git a/setup.cfg b/setup.cfg +new file mode 100644 +index 0000000..01016fc +--- /dev/null ++++ b/setup.cfg +@@ -0,0 +1,50 @@ ++[metadata] ++name = trajdata ++version = 1.1.0 ++author = Boris Ivanovic ++author_email = bivanovic@nvidia.com ++description = A unified interface to many trajectory forecasting datasets. ++long_description = file: README.md ++long_description_content_type = text/markdown ++# license = Apache License 2.0 ++url = https://github.com/nvr-avg/trajdata ++classifiers = ++ Development Status :: 3 - Alpha ++ Intended Audience :: Developers ++ Programming Language :: Python :: 3.8 ++ License :: OSI Approved :: Apache Software License ++ ++[options] ++package_dir = ++ = src ++packages = find: ++python_requires = >=3.8 ++install_requires = ++ numpy>=1.19 ++ tqdm>=4.62 ++ matplotlib>=3.5 ++ dill>=0.3.4 ++ pandas>=1.4.1 ++ pyarrow>=7.0.0 ++ torch>=1.10.2 ++ zarr>=2.11.0 ++ kornia>=0.6.4 ++ seaborn>=0.12 ++ intervaltree ++ ++[options.packages.find] ++where = src ++ ++[options.extras_require] ++dev = ++ black ++ isort ++ pytest ++ pytest-xdist ++ twine ++ build ++nusc = ++ nuscenes-devkit==1.1.9 ++lyft = ++ protobuf==3.20.3 ++ l5kit==1.5.0 +diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py +index eb0efef..654977c 100644 +--- a/src/trajdata/caching/df_cache.py ++++ b/src/trajdata/caching/df_cache.py +@@ -30,6 +30,7 @@ from trajdata.data_structures.scene_metadata import Scene + from trajdata.data_structures.state import NP_STATE_TYPES, StateArray + from trajdata.maps.traffic_light_status import TrafficLightStatus + from trajdata.utils import arr_utils, df_utils, raster_utils, state_utils ++from trajdata.utils.scene_utils import is_integer_robust + + STATE_COLS: Final[List[str]] = ["x", "y", "z", "vx", "vy", "ax", "ay"] + EXTENT_COLS: Final[List[str]] = ["length", "width", "height"] +@@ -320,7 +321,7 @@ class DataFrameCache(SceneCache): + def _upsample_data( + self, new_index: pd.MultiIndex, upsample_dt_ratio: float, method: str + ) -> pd.DataFrame: +- upsample_dt_factor: int = int(upsample_dt_ratio) ++ upsample_dt_factor: int = int(round(upsample_dt_ratio)) + + interpolated_df: pd.DataFrame = pd.DataFrame( + index=new_index, columns=self.scene_data_df.columns +@@ -353,7 +354,7 @@ class DataFrameCache(SceneCache): + def _downsample_data( + self, new_index: pd.MultiIndex, downsample_dt_ratio: float + ) -> pd.DataFrame: +- downsample_dt_factor: int = int(downsample_dt_ratio) ++ downsample_dt_factor: int = int(round(downsample_dt_ratio)) + + subsample_index: pd.MultiIndex = new_index.set_levels( + new_index.levels[1] * downsample_dt_factor, level=1 +@@ -368,7 +369,8 @@ class DataFrameCache(SceneCache): + def interpolate_data(self, desired_dt: float, method: str = "linear") -> None: + upsample_dt_ratio: float = self.scene.env_metadata.dt / desired_dt + downsample_dt_ratio: float = desired_dt / self.scene.env_metadata.dt +- if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer(): ++ ++ if not is_integer_robust(upsample_dt_ratio) and not is_integer_robust(downsample_dt_ratio): + raise ValueError( + f"{str(self.scene)}'s dt of {self.scene.dt}s " + f"is not integer divisible by the desired dt {desired_dt}s." +diff --git a/src/trajdata/data_structures/agent.py b/src/trajdata/data_structures/agent.py +index 7d21217..0687d8e 100644 +--- a/src/trajdata/data_structures/agent.py ++++ b/src/trajdata/data_structures/agent.py +@@ -11,6 +11,7 @@ class AgentType(IntEnum): + PEDESTRIAN = 2 + BICYCLE = 3 + MOTORCYCLE = 4 ++ STATIC = 5 + + + class Extent: +diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py +index c93fb8b..1903557 100644 +--- a/src/trajdata/data_structures/batch.py ++++ b/src/trajdata/data_structures/batch.py +@@ -1,6 +1,6 @@ + from __future__ import annotations + +-from dataclasses import dataclass ++from dataclasses import dataclass, replace + from typing import Dict, List, Optional, Union + + import torch +@@ -9,7 +9,7 @@ from torch import Tensor + from trajdata.data_structures.agent import AgentType + from trajdata.data_structures.state import StateTensor + from trajdata.maps import VectorMap +-from trajdata.utils.arr_utils import PadDirection ++from trajdata.utils.arr_utils import PadDirection, batch_nd_transform_xyvvaahh_pt, roll_with_tensor, transform_xyh_torch + + + @dataclass +@@ -39,6 +39,10 @@ class AgentBatch: + map_names: Optional[List[str]] + maps: Optional[Tensor] + maps_resolution: Optional[Tensor] ++ lane_xyh: Optional[Tensor] ++ lane_adj: Optional[Tensor] ++ lane_ids: Optional[List[List[str]]] ++ lane_mask: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] + rasters_from_world_tf: Optional[Tensor] + agents_from_world_tf: Tensor +@@ -141,12 +145,16 @@ class AgentBatch: + maps_resolution=_filter(self.maps_resolution) + if self.maps_resolution is not None + else None, +- vector_maps=_filter(self.vector_maps) ++ vector_maps=_filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None, + rasters_from_world_tf=_filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None, ++ lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, ++ lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, ++ lane_ids=self.lane_ids, ++ lane_mask=_filter(self.lane_mask) if self.lane_mask is not None else None, + agents_from_world_tf=_filter(self.agents_from_world_tf), + scene_ids=_filter_tensor_or_list(self.scene_ids), + history_pad_dir=self.history_pad_dir, +@@ -155,6 +163,48 @@ class AgentBatch: + }, + ) + ++ def to_scene_batch(self, agent_ind: int) -> SceneBatch: ++ """ ++ Converts AgentBatch to SeceneBatch by combining neighbors and agent. ++ ++ The agent of AgentBatch will be treated as if it was the last neighbor. ++ self.extras will be simply copied over, any custom conversion must be ++ implemented externally. ++ """ ++ ++ batch_size = self.neigh_hist.shape[0] ++ num_neigh = self.neigh_hist.shape[1] ++ ++ combine = lambda neigh, agent: torch.cat((neigh, agent.unsqueeze(0)), dim=0) ++ combine_list = lambda neigh, agent: neigh + [agent] ++ ++ return SceneBatch( ++ data_idx=self.data_idx, ++ scene_ts=self.scene_ts, ++ dt=self.dt, ++ num_agents=self.num_neigh + 1, ++ agent_type=combine(self.neigh_types, self.agent_type), ++ centered_agent_state=self.curr_agent_state, # TODO this is not actually the agent but the `global` coordinate frame ++ agent_names=combine_list(["UNKNOWN" for _ in range(num_neigh)], self.agent_name), ++ agent_hist=combine(self.neigh_hist, self.agent_hist), ++ agent_hist_extent=combine(self.neigh_hist_extents, self.agent_hist_extent), ++ agent_hist_len=combine(self.neigh_hist_len, self.agent_hist_len), ++ agent_fut=combine(self.neigh_fut, self.agent_fut), ++ agent_fut_extent=combine(self.neigh_fut_extents, self.agent_fut_extent), ++ agent_fut_len=combine(self.neigh_fut_len, self.agent_fut_len), ++ robot_fut=self.robot_fut, ++ robot_fut_len=self.robot_fut_len, ++ map_names=self.map_names, # TODO ++ maps=self.maps, ++ maps_resolution=self.maps_resolution, ++ vector_maps=self.vector_maps, ++ rasters_from_world_tf=self.rasters_from_world_tf, ++ centered_agent_from_world_tf=self.agents_from_world_tf, ++ centered_world_from_agent_tf=torch.linalg.inv(self.agents_from_world_tf), ++ scene_ids=self.scene_ids, ++ history_pad_dir=self.history_pad_dir, ++ extras=self.extras, ++ ) + + @dataclass + class SceneBatch: +@@ -164,7 +214,8 @@ class SceneBatch: + num_agents: Tensor + agent_type: Tensor + centered_agent_state: StateTensor +- agent_names: List[str] ++ agent_names: List[List[str]] ++ track_ids:Optional[List[List[str]]] + agent_hist: StateTensor + agent_hist_extent: Tensor + agent_hist_len: Tensor +@@ -177,6 +228,10 @@ class SceneBatch: + maps: Optional[Tensor] + maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] ++ lane_xyh: Optional[Tensor] ++ lane_adj: Optional[Tensor] ++ lane_ids: Optional(List[List[str]]) ++ lane_mask: Optional[Tensor] + rasters_from_world_tf: Optional[Tensor] + centered_agent_from_world_tf: Tensor + centered_world_from_agent_tf: Tensor +@@ -186,12 +241,19 @@ class SceneBatch: + + def to(self, device) -> None: + excl_vals = { ++ "num_agents", + "agent_names", ++ "track_ids", ++ "agent_type", ++ "agent_hist_len", ++ "agent_fut_len", ++ "robot_fut_len", + "map_names", + "vector_maps", + "history_pad_dir", + "scene_ids", + "extras", ++ "lane_ids", + } + + for val in vars(self).keys(): +@@ -205,6 +267,31 @@ class SceneBatch: + self.extras[key] = val.__to__(device, non_blocking=True) + else: + self.extras[key] = val.to(device, non_blocking=True) ++ return self ++ ++ def astype(self, dtype) -> None: ++ new_obj = replace(self) ++ excl_vals = { ++ "num_agents", ++ "agent_names", ++ "track_ids", ++ "agent_type", ++ "agent_hist_len", ++ "agent_fut_len", ++ "robot_fut_len", ++ "map_names", ++ "vector_maps", ++ "history_pad_dir", ++ "scene_ids", ++ "extras", ++ "lane_ids", ++ } ++ ++ for val in vars(self).keys(): ++ tensor_val = getattr(self, val) ++ if val not in excl_vals and tensor_val is not None: ++ setattr(new_obj, val, tensor_val.type(dtype)) ++ return new_obj + + def agent_types(self) -> List[AgentType]: + unique_types: Tensor = torch.unique(self.agent_type) +@@ -214,13 +301,31 @@ class SceneBatch: + if unique_type >= 0 + ] + +- def for_agent_type(self, agent_type: AgentType) -> SceneBatch: +- match_type = self.agent_type == agent_type +- return self.filter_batch(match_type) ++ def copy(self): ++ # Shallow copy ++ return replace(self) ++ ++ def convert_pad_direction(self, pad_dir: PadDirection) -> SceneBatch: ++ if self.history_pad_dir == pad_dir: ++ return self ++ batch: SceneBatch = self.copy() ++ if self.history_pad_dir == PadDirection.BEFORE: ++ # n, n, -2 , -1, 0 --> -2, -1, 0, n, n ++ shifts = batch.agent_hist_len ++ else: ++ # -2, -1, 0, n, n --> n, n, -2 , -1, 0 ++ shifts = -batch.agent_hist_len ++ batch.agent_hist = roll_with_tensor(batch.agent_hist, shifts, dim=-2) ++ batch.agent_hist_extent = roll_with_tensor(batch.agent_hist_extent, shifts, dim=-2) ++ batch.history_pad_dir = pad_dir ++ return batch + +- def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: ++ def filter_batch(self, filter_mask: torch.Tensor) -> SceneBatch: + """Build a new batch with elements for which filter_mask[i] == True.""" + ++ if filter_mask.ndim != 1: ++ raise ValueError("Expected 1d filter mask.") ++ + # Some of the tensors might be on different devices, so we define some convenience functions + # to make sure the filter_mask is always on the same device as the tensor we are indexing. + filter_mask_dict = {} +@@ -229,7 +334,8 @@ class SceneBatch: + self.agent_hist.device + ) + +- _filter = lambda tensor: tensor[filter_mask_dict[str(tensor.device)]] ++ # Use tensor.__class__ to keep TensorState. This might ++ _filter = lambda tensor: tensor.__class__(tensor[filter_mask_dict[str(tensor.device)]]) + _filter_tensor_or_list = lambda tensor_or_list: ( + _filter(tensor_or_list) + if isinstance(tensor_or_list, torch.Tensor) +@@ -248,6 +354,8 @@ class SceneBatch: + dt=_filter(self.dt), + num_agents=_filter(self.num_agents), + agent_type=_filter(self.agent_type), ++ agent_names=_filter_tensor_or_list(self.agent_names), ++ track_ids=_filter_tensor_or_list(self.track_ids), + centered_agent_state=_filter(self.centered_agent_state), + agent_hist=_filter(self.agent_hist), + agent_hist_extent=_filter(self.agent_hist_extent), +@@ -266,9 +374,13 @@ class SceneBatch: + maps_resolution=_filter(self.maps_resolution) + if self.maps_resolution is not None + else None, +- vector_maps=_filter(self.vector_maps) ++ vector_maps=_filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None, ++ lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, ++ lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, ++ lane_ids = self.lane_ids, ++ lane_mask = _filter(self.lane_mask) if self.lane_mask is not None else None, + rasters_from_world_tf=_filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None, +@@ -277,7 +389,7 @@ class SceneBatch: + scene_ids=_filter_tensor_or_list(self.scene_ids), + history_pad_dir=self.history_pad_dir, + extras={ +- key: _filter_tensor_or_list(val, filter_mask) ++ key: _filter_tensor_or_list(val) + for key, val in self.extras.items() + }, + ) +@@ -321,31 +433,68 @@ class SceneBatch: + scene_ts=self.scene_ts, + dt=self.dt, + agent_name=index_agent_list(self.agent_names), ++ track_ids=index_agent_list(self.track_ids), + agent_type=index_agent(self.agent_type), + curr_agent_state=self.centered_agent_state, # TODO this is not actually the agent but the `global` coordinate frame +- agent_hist=index_agent(self.agent_hist), ++ agent_hist=StateTensor.from_array(index_agent(self.agent_hist), self.agent_hist._format), + agent_hist_extent=index_agent(self.agent_hist_extent), + agent_hist_len=index_agent(self.agent_hist_len), +- agent_fut=index_agent(self.agent_fut), ++ agent_fut=StateTensor.from_array(index_agent(self.agent_fut), self.agent_fut._format), + agent_fut_extent=index_agent(self.agent_fut_extent), + agent_fut_len=index_agent(self.agent_fut_len), + num_neigh=self.num_agents - 1, + neigh_types=index_neighbors(self.agent_type), +- neigh_hist=index_neighbors(self.agent_hist), ++ neigh_hist=StateTensor.from_array(index_neighbors(self.agent_hist), self.agent_hist._format), + neigh_hist_extents=index_neighbors(self.agent_hist_extent), + neigh_hist_len=index_neighbors(self.agent_hist_len), +- neigh_fut=index_neighbors(self.agent_fut), ++ neigh_fut=StateTensor.from_array(index_neighbors(self.agent_fut), self.agent_fut._format), + neigh_fut_extents=index_neighbors(self.agent_fut_extent), + neigh_fut_len=index_neighbors(self.agent_fut_len), + robot_fut=self.robot_fut, + robot_fut_len=self.robot_fut_len, +- map_names=index_agent_list(self.map_names), +- maps=index_agent(self.maps), +- vector_maps=index_agent(self.vector_maps), +- maps_resolution=index_agent(self.maps_resolution), +- rasters_from_world_tf=index_agent(self.rasters_from_world_tf), ++ map_names=self.map_names, ++ maps=self.maps, ++ vector_maps=self.vector_maps, ++ maps_resolution=self.maps_resolution, ++ rasters_from_world_tf=self.rasters_from_world_tf, + agents_from_world_tf=self.centered_agent_from_world_tf, + scene_ids=self.scene_ids, + history_pad_dir=self.history_pad_dir, + extras=self.extras, + ) ++ ++ def apply_transform(self, tf: torch.Tensor, dtype: Optional[torch.dtype] = None) -> SceneBatch: ++ """ ++ Applies a transformation matrix to all coordinates stored in the SceneBatch. ++ ++ Returns a shallow copy, only coordinate fields are replaced. ++ self.extras will be simply copied over (shallow copy), any custom conversion must be ++ implemented externally. ++ """ ++ assert tf.ndim == 3 # b, 3, 3 ++ assert tf.shape[-1] == 3 and tf.shape[-1] == 3 ++ assert tf.dtype == torch.double # tf should be double precision, otherwise we have large numerical errors ++ if dtype is None: ++ dtype = self.agent_hist.dtype ++ ++ # Shallow copy ++ batch: SceneBatch = replace(self) ++ ++ # TODO (pkarkus) support generic format ++ assert batch.agent_hist._format == "x,y,xd,yd,xdd,ydd,s,c" ++ assert batch.agent_fut._format == "x,y,xd,yd,xdd,ydd,s,c" ++ state_class = batch.agent_hist.__class__ ++ ++ # Transforms ++ batch.agent_hist = state_class(batch_nd_transform_xyvvaahh_pt(batch.agent_hist.double(), tf).type(dtype)) ++ batch.agent_fut = state_class(batch_nd_transform_xyvvaahh_pt(batch.agent_fut.double(), tf).type(dtype)) ++ batch.rasters_from_world_tf = tf.unsqueeze(1) @ batch.rasters_from_world_tf if batch.rasters_from_world_tf is not None else None ++ batch.centered_agent_from_world_tf = tf @ batch.centered_agent_from_world_tf ++ centered_world_from_agent_tf = torch.linalg.inv(batch.centered_agent_from_world_tf) ++ if batch.lane_xyh is not None: ++ batch.lane_xyh = transform_xyh_torch(batch.lane_xyh.double(), tf).type(dtype) ++ # sanity check ++ assert torch.isclose(batch.centered_world_from_agent_tf @ torch.linalg.inv(tf), centered_world_from_agent_tf, atol=1e-5).all() ++ batch.centered_world_from_agent_tf = centered_world_from_agent_tf ++ ++ return batch +diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py +index f18e34d..012f3b8 100644 +--- a/src/trajdata/data_structures/batch_element.py ++++ b/src/trajdata/data_structures/batch_element.py +@@ -10,6 +10,11 @@ from trajdata.data_structures.scene import SceneTime, SceneTimeAgent + from trajdata.data_structures.state import StateArray + from trajdata.maps import MapAPI, RasterizedMapPatch, VectorMap + from trajdata.utils.state_utils import convert_to_frame_state, transform_from_frame ++from trajdata.utils.arr_utils import transform_xyh_np, get_close_lanes ++ ++from trajdata.utils.map_utils import LaneSegRelation ++ ++ + + + class AgentBatchElement: +@@ -158,6 +163,20 @@ class AgentBatchElement: + else None, + **vector_map_params if vector_map_params is not None else None, + ) ++ if vector_map_params.get("calc_lane_graph", False): ++ # not tested ++ ego_xyh = np.concatenate([self.curr_agent_state_np.position, self.curr_agent_state_np.heading]) ++ num_pts = vector_map_params.get("num_lane_pts", 30) ++ max_num_lanes = vector_map_params.get("max_num_lanes",20) ++ remove_single_successor = vector_map_params.get("remove_single_successor",False) ++ radius = vector_map_params.get("radius", 100) ++ self.num_lanes,self.lane_xyh,self.lane_adj,self.lane_ids = gen_lane_graph(self.vec_map,ego_xyh,self.agent_from_world_tf,num_pts,max_num_lanes,radius,remove_single_successor) ++ ++ else: ++ self.lane_xyh = None ++ self.lane_adj = None ++ self.lane_ids=list() ++ self.num_lanes = 0 + + self.scene_id = scene_time_agent.scene.name + +@@ -404,7 +423,7 @@ class SceneBatchElement: + nearby_agents, self.agent_types_np = self.get_nearby_agents( + scene_time, self.centered_agent, distance_limit + ) +- ++ self.agents = nearby_agents + self.num_agents = len(nearby_agents) + self.agent_names = [agent.name for agent in nearby_agents] + ( +@@ -440,7 +459,22 @@ class SceneBatchElement: + self.cache if self.cache.is_traffic_light_data_cached() else None, + **vector_map_params if vector_map_params is not None else None, + ) +- ++ if vector_map_params.get("calc_lane_graph", False): ++ # not tested ++ ego_xyh = np.concatenate([self.centered_agent_state_np.position, self.centered_agent_state_np.heading]) ++ num_pts = vector_map_params.get("num_lane_pts", 30) ++ max_num_lanes = vector_map_params.get("max_num_lanes",20) ++ self.num_lanes,self.lane_xyh,self.lane_adj,self.lane_ids = gen_lane_graph(self.vec_map,ego_xyh,self.centered_agent_from_world_tf,num_pts,max_num_lanes) ++ ++ else: ++ self.lane_xyh = None ++ self.lane_adj = None ++ self.lane_ids = list() ++ self.num_lanes = 0 ++ ++ ++ ++ + self.scene_id = scene_time.scene.name + + ### ROBOT DATA ### +@@ -585,6 +619,73 @@ class SceneBatchElement: + ).view(self.cache.obs_type) + return robot_curr_and_fut_np + ++ ++def gen_lane_graph(vec_map,ego_xyh,agent_from_world,num_pts=20,max_num_lanes=15,radius=150,remove_single_successor=True): ++ close_lanes,dis = get_close_lanes(radius,ego_xyh,vec_map,num_pts) ++ lanes_by_id = {lane.id:lane for lane in close_lanes} ++ dis_by_id = {lane.id:dis[i] for i,lane in enumerate(close_lanes)} ++ if remove_single_successor: ++ for lane in close_lanes: ++ ++ while len(lane.next_lanes) == 1: ++ # if there are more than one succeeding lanes, then we abort the merging ++ next_id = list(lane.next_lanes)[0] ++ ++ if next_id in lanes_by_id: ++ next_lane = lanes_by_id[next_id] ++ shared_next = False ++ for id in next_lane.prev_lanes: ++ if id != lane.id and id in lanes_by_id: ++ shared_next = True ++ break ++ if shared_next: ++ # if the next lane shares two prev lanes in the close_lanes, then we abort the merging ++ break ++ lane.combine_next(lanes_by_id[next_id]) ++ lanes_by_id.pop(next_id) ++ else: ++ break ++ close_lanes = list(lanes_by_id.values()) ++ dis = np.array([dis_by_id[lane.id] for lane in close_lanes]) ++ num_lanes = len(close_lanes) ++ if num_lanes > max_num_lanes: ++ idx = dis.argsort()[:max_num_lanes] ++ close_lanes = [lane for i,lane in enumerate(close_lanes) if i in idx] ++ num_lanes = max_num_lanes ++ ++ if num_lanes >0: ++ lane_xyh = list() ++ lane_adj = np.zeros([len(close_lanes), len(close_lanes)],dtype=np.int32) ++ lane_ids = [lane.id for lane in close_lanes] ++ ++ for i,lane in enumerate(close_lanes): ++ center = lane.center.interpolate(num_pts).points[:,[0,1,3]] ++ center_local = transform_xyh_np(center, agent_from_world[None]) ++ lane_xyh.append(center_local) ++ # construct lane adjacency matrix ++ for adj_lane_id in lane.next_lanes: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.NEXT.value ++ ++ for adj_lane_id in lane.prev_lanes: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.PREV.value ++ ++ for adj_lane_id in lane.adj_lanes_left: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.LEFT.value ++ ++ for adj_lane_id in lane.adj_lanes_right: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.RIGHT.value ++ lane_xyh = np.stack(lane_xyh, axis=0) ++ lane_xyh = lane_xyh ++ lane_adj = lane_adj ++ else: ++ lane_xyh = np.zeros([0,num_pts,3]) ++ lane_adj = np.zeros([0,0]) ++ lane_ids = list() ++ return num_lanes,lane_xyh,lane_adj,lane_ids + + def is_agent_stationary(cache: SceneCache, agent_info: AgentMetadata) -> bool: + # Agent is considered stationary if it moves less than 1m between the first and last valid timestep. +diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py +index 1b4d2be..2b911d6 100644 +--- a/src/trajdata/data_structures/collation.py ++++ b/src/trajdata/data_structures/collation.py +@@ -11,8 +11,9 @@ from torch.nn.utils.rnn import pad_sequence + from trajdata.augmentation import BatchAugmentation + from trajdata.data_structures.batch import AgentBatch, SceneBatch + from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement ++from trajdata.maps import VectorMap, RasterizedMapPatch ++from trajdata.utils.map_utils import batch_transform_raster_maps + from trajdata.data_structures.state import TORCH_STATE_TYPES +-from trajdata.maps import VectorMap + from trajdata.utils import arr_utils + + +@@ -34,10 +35,26 @@ def _collate_data(elems): + else: + return torch.as_tensor(np.stack(elems)) + ++def _collate_lane_graph(elems): ++ num_lanes = [elem.num_lanes for elem in elems] ++ bs = len(elems) ++ M = max(num_lanes) ++ lane_xyh = np.zeros([bs,M,*elems[0].lane_xyh.shape[-2:]]) ++ lane_adj = np.zeros([bs,M,M],dtype=int) ++ lane_ids = list() ++ lane_mask = np.zeros([bs,M],dtype=int) ++ for i,elem in enumerate(elems): ++ lane_xyh[i,:num_lanes[i]] = elem.lane_xyh ++ lane_adj[i,:num_lanes[i],:num_lanes[i]] = elem.lane_adj ++ lane_ids.append(elem.lane_ids) ++ lane_mask[i,:num_lanes[i]] = 1 ++ return torch.as_tensor(lane_xyh),torch.as_tensor(lane_adj), torch.as_tensor(lane_mask), lane_ids + + def raster_map_collate_fn_agent( + batch_elems: List[AgentBatchElement], + ): ++ # TODO(pkarkus) refactor with batch_rotate_raster_maps ++ + if batch_elems[0].map_patch is None: + return None, None, None, None + +@@ -181,89 +198,48 @@ def raster_map_collate_fn_scene( + batch_elems: List[SceneBatchElement], + max_agent_num: Optional[int] = None, + pad_value: Any = np.nan, +-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: ++) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + + if batch_elems[0].map_patches is None: + return None, None, None, None + +- patch_size: int = batch_elems[0].map_patches[0].crop_size +- assert all( +- batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems +- ) +- ++ # Collect map patches for all elements and agents into a flat list + map_names: List[str] = list() + num_agents: List[int] = list() +- agents_rasters_from_world_tfs: List[np.ndarray] = list() +- agents_patches: List[np.ndarray] = list() +- agents_rot_angles_list: List[float] = list() +- agents_res_list: List[float] = list() ++ map_patches: List[RasterizedMapPatch] = list() + + for elem in batch_elems: + map_names.append(elem.map_name) + num_agents.append(min(elem.num_agents, max_agent_num)) +- agents_rasters_from_world_tfs += [ +- x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] +- ] +- agents_patches += [x.data for x in elem.map_patches[:max_agent_num]] +- agents_rot_angles_list += [ +- x.rot_angle for x in elem.map_patches[:max_agent_num] +- ] +- agents_res_list += [x.resolution for x in elem.map_patches[:max_agent_num]] ++ map_patches += elem.map_patches[:max_agent_num] + +- patch_data: Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) +- agents_rot_angles: Tensor = torch.as_tensor( +- np.stack(agents_rot_angles_list), dtype=torch.float +- ) +- agents_rasters_from_world_tf: Tensor = torch.as_tensor( +- np.stack(agents_rasters_from_world_tfs), dtype=torch.float +- ) +- agents_resolution: Tensor = torch.as_tensor( +- np.stack(agents_res_list), dtype=torch.float ++ # Batch transform map patches and pad ++ ( ++ rot_crop_patches, ++ agents_resolution, ++ agents_rasters_from_world_tf ++ ) = batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches, num_agents, max_agent_num, pad_value + ) + +- patch_size_y, patch_size_x = patch_data.shape[-2:] +- center_y: int = patch_size_y // 2 +- center_x: int = patch_size_x // 2 +- half_extent: int = patch_size // 2 +- +- if torch.count_nonzero(agents_rot_angles) == 0: +- agents_rasters_from_world_tf = torch.bmm( +- torch.tensor( +- [ +- [ +- [1.0, 0.0, half_extent], +- [0.0, 1.0, half_extent], +- [0.0, 0.0, 1.0], +- ] +- ], +- dtype=agents_rasters_from_world_tf.dtype, +- device=agents_rasters_from_world_tf.device, +- ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), +- agents_rasters_from_world_tf, +- ) ++ return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + +- rot_crop_patches = patch_data +- else: +- agents_rasters_from_world_tf = torch.bmm( +- arr_utils.transform_matrices( +- -agents_rot_angles, +- torch.tensor([[half_extent, half_extent]]).expand( +- (agents_rot_angles.shape[0], -1) +- ), +- ), +- agents_rasters_from_world_tf, +- ) + +- # Batch rotating patches by rot_angles. +- rot_patches: Tensor = rotate(patch_data, torch.rad2deg(agents_rot_angles)) ++def batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches: List[RasterizedMapPatch], ++ num_agents: List[int], ++ max_agent_num: Optional[int] = None, ++ pad_value: Any = np.nan, ++) -> Tuple[Tensor, Tensor, Tensor]: + +- # Center cropping via slicing. +- rot_crop_patches = rot_patches[ +- ..., +- center_y - half_extent : center_y + half_extent, +- center_x - half_extent : center_x + half_extent, +- ] ++ # Batch transform map patches ++ ( ++ rot_crop_patches, ++ agents_resolution, ++ agents_rasters_from_world_tf ++ ) = batch_transform_raster_maps(map_patches) + ++ # Separate batch and agents + rot_crop_patches = split_pad_crop( + rot_crop_patches, num_agents, pad_value=pad_value, desired_size=max_agent_num + ) +@@ -278,7 +254,7 @@ def raster_map_collate_fn_scene( + agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num + ) + +- return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf ++ return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + + + def agent_collate_fn( +@@ -673,6 +649,10 @@ def agent_collate_fn( + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] ++ ++ lane_xyh,lane_adj,lane_mask,lane_ids = None,None,None,None ++ if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: ++ lane_xyh,lane_adj,lane_mask, lane_ids = _collate_lane_graph(batch_elems) + + agents_from_world_tf = torch.as_tensor( + np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), +@@ -712,6 +692,10 @@ def agent_collate_fn( + robot_fut_len=robot_future_len, + map_names=map_names, + maps=map_patches, ++ lane_xyh=lane_xyh, ++ lane_adj=lane_adj, ++ lane_mask=lane_mask, ++ lane_ids = lane_ids, + maps_resolution=maps_resolution, + vector_maps=vector_maps, + rasters_from_world_tf=rasters_from_world_tf, +@@ -781,6 +765,9 @@ def scene_collate_fn( + return_dict: bool, + pad_format: str, + batch_augments: Optional[List[BatchAugmentation]] = None, ++ desired_num_agents = None, ++ desired_hist_len=None, ++ desired_fut_len=None, + ) -> SceneBatch: + batch_size: int = len(batch_elems) + history_pad_dir: arr_utils.PadDirection = ( +@@ -800,6 +787,9 @@ def scene_collate_fn( + AgentObsTensor = TORCH_STATE_TYPES[obs_format] + + max_agent_num: int = max(elem.num_agents for elem in batch_elems) ++ if desired_num_agents is not None: ++ max_agent_num = max(max_agent_num,desired_num_agents) ++ + + centered_agent_state: List[AgentStateTensor] = list() + agents_types: List[Tensor] = list() +@@ -820,6 +810,10 @@ def scene_collate_fn( + + max_history_len: int = max(elem.agent_history_lens_np.max() for elem in batch_elems) + max_future_len: int = max(elem.agent_future_lens_np.max() for elem in batch_elems) ++ if desired_hist_len is not None: ++ max_history_len = max(max_history_len,desired_hist_len) ++ if desired_fut_len is not None: ++ max_future_len = max(max_future_len,desired_fut_len) + + robot_future: List[AgentObsTensor] = list() + robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) +@@ -951,6 +945,10 @@ def scene_collate_fn( + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + ++ lane_xyh,lane_adj,lane_mask, lane_ids = None,None,None,None ++ if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: ++ lane_xyh,lane_adj,lane_mask, lane_ids = _collate_lane_graph(batch_elems) ++ + centered_agent_from_world_tf = torch.as_tensor( + np.stack( + [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] +@@ -990,6 +988,7 @@ def scene_collate_fn( + agent_type=agents_types_t, + centered_agent_state=centered_agent_state_t, + agent_names=agent_names, ++ track_ids = None, + agent_hist=agents_histories_t, + agent_hist_extent=agents_history_extents_t, + agent_hist_len=agents_history_len, +@@ -1000,6 +999,10 @@ def scene_collate_fn( + robot_fut_len=robot_future_len, + map_names=map_names, + maps=map_patches, ++ lane_xyh = lane_xyh, ++ lane_adj = lane_adj, ++ lane_mask = lane_mask, ++ lane_ids = lane_ids, + maps_resolution=maps_resolution, + vector_maps=vector_maps, + rasters_from_world_tf=rasters_from_world_tf, +diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py +index 64f80eb..609e25e 100644 +--- a/src/trajdata/dataset.py ++++ b/src/trajdata/dataset.py +@@ -1,4 +1,5 @@ + import gc ++import json + import random + import time + from collections import defaultdict +@@ -49,7 +50,7 @@ from trajdata.data_structures import ( + from trajdata.dataset_specific import RawDataset + from trajdata.maps.map_api import MapAPI + from trajdata.parallel import ParallelDatasetPreprocessor, scene_paths_collate_fn +-from trajdata.utils import agent_utils, env_utils, scene_utils, string_utils ++from trajdata.utils import agent_utils, env_utils, py_utils, scene_utils, string_utils + from trajdata.utils.parallel_utils import parallel_iapply + + # TODO(bivanovic): Move this to a better place in the codebase. +@@ -111,6 +112,7 @@ class UnifiedDataset(Dataset): + cache_location: str = "~/.unified_data_cache", + rebuild_cache: bool = False, + rebuild_maps: bool = False, ++ save_index: bool = False, + num_workers: int = 0, + verbose: bool = False, + extras: Dict[str, Callable[..., np.ndarray]] = dict(), +@@ -152,12 +154,17 @@ class UnifiedDataset(Dataset): + cache_location (str, optional): Where to store and load preprocessed, cached data. Defaults to "~/.unified_data_cache". + rebuild_cache (bool, optional): If True, process and cache trajectory data even if it is already cached. Defaults to False. + rebuild_maps (bool, optional): If True, process and cache maps even if they are already cached. Defaults to False. ++ save_index (bool, optional): If True, save the resulting agent (or scene) data index after it is computed (speeding up subsequent initializations with the same argument values). + num_workers (int, optional): Number of parallel workers to use for dataset preprocessing and loading. Defaults to 0. + verbose (bool, optional): If True, print internal data loading information. Defaults to False. + extras (Dict[str, Callable[..., np.ndarray]], optional): Adds extra data to each batch element. Each Callable must take as input a filled {Agent,Scene}BatchElement and return an ndarray which will subsequently be added to the batch element's `extra` dict. + transforms (Iterable[Callable], optional): Allows for custom modifications of batch elements. Each Callable must take in a filled {Agent,Scene}BatchElement and return a {Agent,Scene}BatchElement. + rank (int, optional): Proccess rank when using torch DistributedDataParallel for multi-GPU training. Only the rank 0 process will be used for caching. + """ ++ self.desired_data: List[str] = desired_data ++ self.scene_description_contains: Optional[ ++ List[str] ++ ] = scene_description_contains + self.centric: str = centric + self.desired_dt: float = desired_dt + +@@ -204,6 +211,11 @@ class UnifiedDataset(Dataset): + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. + "collate": False, ++ # Whether loaded maps should be stored in memory (memoized) for later re-use. ++ # For datasets which provide full maps ahead-of-time (i.e., all except Waymo), ++ # this should be True. However, for Waymo it should be False because maps ++ # are already partitioned geographically and keeping them around significantly grows memory. ++ "keep_in_memory": True, + } + ) + if self.desired_dt is not None: +@@ -233,13 +245,15 @@ class UnifiedDataset(Dataset): + self.torch_obs_type = TORCH_STATE_TYPES[obs_format] + + # Ensuring scene description queries are all lowercase +- if scene_description_contains is not None: +- scene_description_contains = [s.lower() for s in scene_description_contains] ++ if self.scene_description_contains is not None: ++ self.scene_description_contains = [ ++ s.lower() for s in self.scene_description_contains ++ ] + + self.envs: List[RawDataset] = env_utils.get_raw_datasets(data_dirs) + self.envs_dict: Dict[str, RawDataset] = {env.name: env for env in self.envs} + +- matching_datasets: List[SceneTag] = self.get_matching_scene_tags(desired_data) ++ matching_datasets: List[SceneTag] = self._get_matching_scene_tags(desired_data) + if self.verbose: + print( + "Loading data for matched scene tags:", +@@ -249,7 +263,7 @@ class UnifiedDataset(Dataset): + + self._map_api: Optional[MapAPI] = None + if self.incl_vector_map: +- self._map_api = MapAPI(self.cache_path) ++ self._map_api = MapAPI(self.cache_path, keep_in_memory=vector_map_params.get("keep_in_memory", True)) + + all_scenes_list: Union[List[SceneMetadata], List[Scene]] = list() + for env in self.envs: +@@ -257,8 +271,8 @@ class UnifiedDataset(Dataset): + all_data_cached: bool = False + all_maps_cached: bool = not env.has_maps or not require_map_cache + if self.env_cache.env_is_cached(env.name) and not self.rebuild_cache: +- scenes_list: List[Scene] = self.get_desired_scenes_from_env( +- matching_datasets, scene_description_contains, env ++ scenes_list: List[Scene] = self._get_desired_scenes_from_env( ++ matching_datasets, env + ) + + all_data_cached: bool = all( +@@ -314,9 +328,9 @@ class UnifiedDataset(Dataset): + ): + distributed.barrier() + +- scenes_list: List[SceneMetadata] = self.get_desired_scenes_from_env( +- matching_datasets, scene_description_contains, env +- ) ++ scenes_list: List[ ++ SceneMetadata ++ ] = self._get_desired_scenes_from_env(matching_datasets, env) + + if self.incl_vector_map and env.metadata.map_locations is not None: + # env.metadata.map_locations can be none for map-containing +@@ -330,7 +344,7 @@ class UnifiedDataset(Dataset): + all_scenes_list += scenes_list + + # List of cached scene paths. +- scene_paths: List[Path] = self.preprocess_scene_data( ++ scene_paths: List[Path] = self._preprocess_scene_data( + all_scenes_list, num_workers + ) + if self.verbose: +@@ -343,7 +357,11 @@ class UnifiedDataset(Dataset): + data_index: Union[ + List[Tuple[str, int, np.ndarray]], + List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], +- ] = self.get_data_index(num_workers, scene_paths) ++ ] ++ if self._index_cache_path().exists(): ++ data_index = self._load_data_index() ++ else: ++ data_index = self._get_data_index(num_workers, scene_paths) + + # Done with this list. Cutting memory usage because + # of multiprocessing later on. +@@ -360,8 +378,99 @@ class UnifiedDataset(Dataset): + ) + self._data_len: int = len(self._data_index) + ++ # Use only rank 0 process for caching when using multi-GPU torch training. ++ if save_index and rank == 0: ++ if self._index_cache_path().exists(): ++ print( ++ "WARNING: Overwriting already-cached data index (since save_index is True).", ++ flush=True, ++ ) ++ ++ self._cache_data_index(data_index) ++ if ( ++ distributed.is_initialized() ++ and distributed.get_world_size() > 1 ++ ): ++ distributed.barrier() + self._cached_batch_elements = None + ++ def _index_cache_path( ++ self, ret_args: bool = False ++ ) -> Union[Path, Tuple[Path, Dict[str, Any]]]: ++ # Whichever UnifiedDataset arguments affect data indexing are captured ++ # and hashed together here. ++ impactful_args: Dict[str, Any] = { ++ "desired_data": tuple(self.desired_data), ++ "scene_description_contains": tuple(self.scene_description_contains) ++ if self.scene_description_contains is not None ++ else None, ++ "centric": self.centric, ++ "desired_dt": self.desired_dt, ++ "history_sec": self.history_sec, ++ "future_sec": self.future_sec, ++ "incl_robot_future": self.incl_robot_future, ++ "only_types": tuple(t.name for t in self.only_types) ++ if self.only_types is not None ++ else None, ++ "only_predict": tuple(t.name for t in self.only_predict) ++ if self.only_predict is not None ++ else None, ++ "no_types": tuple(t.name for t in self.no_types) ++ if self.no_types is not None ++ else None, ++ "ego_only": self.ego_only, ++ } ++ index_hash: str = py_utils.hash_dict(impactful_args) ++ index_cache_path: Path = self.cache_path / "data_indexes" / index_hash ++ ++ if ret_args: ++ return index_cache_path, impactful_args ++ else: ++ return index_cache_path ++ ++ def _cache_data_index( ++ self, ++ data_index: Union[ ++ List[Tuple[str, int, np.ndarray]], ++ List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ++ ], ++ ) -> None: ++ index_cache_dir, index_args = self._index_cache_path(ret_args=True) ++ ++ # Create it if it doesn't exist yet. ++ index_cache_dir.mkdir(parents=True, exist_ok=True) ++ ++ index_cache_file: Path = index_cache_dir / "data_index.dill" ++ with open(index_cache_file, "wb") as f: ++ dill.dump(data_index, f) ++ ++ args_file: Path = index_cache_dir / "index_args.json" ++ with open(args_file, "w") as f: ++ json.dump(index_args, f, indent=4) ++ ++ print( ++ f"Cached data index to {str(index_cache_file)}", ++ flush=True, ++ ) ++ ++ def _load_data_index( ++ self, ++ ) -> Union[ ++ List[Tuple[str, int, np.ndarray]], ++ List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ++ ]: ++ index_cache_file: Path = self._index_cache_path() / "data_index.dill" ++ with open(index_cache_file, "rb") as f: ++ data_index = dill.load(f) ++ ++ if self.verbose: ++ print( ++ f"Loaded data index from {str(index_cache_file)}", ++ flush=True, ++ ) ++ ++ return data_index ++ + def load_or_create_cache( + self, cache_path: str, num_workers=0, filter_fn=None + ) -> None: +@@ -494,7 +603,7 @@ class UnifiedDataset(Dataset): + f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%." + ) + +- def get_data_index( ++ def _get_data_index( + self, num_workers: int, scene_paths: List[Path] + ) -> Union[ + List[Tuple[str, int, np.ndarray]], +@@ -681,13 +790,16 @@ class UnifiedDataset(Dataset): + return_dict=return_dict, + pad_format=pad_format, + batch_augments=batch_augments, ++ desired_num_agents = self.max_agent_num, ++ desired_hist_len = int(self.history_sec[1]/self.desired_dt), ++ desired_fut_len = int(self.future_sec[1]/self.desired_dt), + ) + else: + raise ValueError(f"{self.centric}-centric data batches are not supported.") + + return collate_fn + +- def get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: ++ def _get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: + # if queries is None: + # return list(chain.from_iterable(env.components for env in self.envs)) + +@@ -708,10 +820,9 @@ class UnifiedDataset(Dataset): + + return matching_scene_tags + +- def get_desired_scenes_from_env( ++ def _get_desired_scenes_from_env( + self, + scene_tags: List[SceneTag], +- scene_description_contains: Optional[List[str]], + env: RawDataset, + ) -> Union[List[Scene], List[SceneMetadata]]: + scenes_list: Union[List[Scene], List[SceneMetadata]] = list() +@@ -721,14 +832,14 @@ class UnifiedDataset(Dataset): + if env.name in scene_tag: + scenes_list += env.get_matching_scenes( + scene_tag, +- scene_description_contains, ++ self.scene_description_contains, + self.env_cache, + self.rebuild_cache, + ) + + return scenes_list + +- def preprocess_scene_data( ++ def _preprocess_scene_data( + self, + scenes_list: Union[List[SceneMetadata], List[Scene]], + num_workers: int, +diff --git a/src/trajdata/dataset_specific/carla/__init__.py b/src/trajdata/dataset_specific/carla/__init__.py +new file mode 100644 +index 0000000..8854e8a +--- /dev/null ++++ b/src/trajdata/dataset_specific/carla/__init__.py +@@ -0,0 +1 @@ ++from .carla_dataset import CarlaDataset +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/carla/carla_dataset.py b/src/trajdata/dataset_specific/carla/carla_dataset.py +new file mode 100644 +index 0000000..2692e26 +--- /dev/null ++++ b/src/trajdata/dataset_specific/carla/carla_dataset.py +@@ -0,0 +1,415 @@ ++import warnings ++from copy import deepcopy ++from pathlib import Path ++from typing import Any, Dict, List, Optional, Tuple, Type, Union ++ ++import pandas as pd ++from nuscenes.eval.prediction.splits import NUM_IN_TRAIN_VAL ++from nuscenes.map_expansion.map_api import NuScenesMap, locations ++from nuscenes.nuscenes import NuScenes ++from nuscenes.utils.splits import create_splits_scenes ++from tqdm import tqdm ++ ++from trajdata.caching import EnvCache, SceneCache ++from trajdata.data_structures.agent import ( ++ Agent, ++ AgentMetadata, ++ AgentType, ++ FixedExtent, ++ VariableExtent, ++) ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene, SceneMetadata ++from trajdata.data_structures.scene_tag import SceneTag ++from trajdata.dataset_specific.nusc import nusc_utils ++from trajdata.dataset_specific.raw_dataset import RawDataset ++from trajdata.dataset_specific.scene_records import CarlaSceneRecord ++from trajdata.maps import VectorMap ++from trajdata.maps.map_api import MapAPI ++ ++from pdb import set_trace as st ++import glob, os ++import pickle ++from collections import defaultdict ++import re ++ ++import torch ++import numpy as np ++ ++carla_to_trajdata_object_type = { ++ 0: AgentType.VEHICLE, ++ 1: AgentType.BICYCLE, # Motorcycle ++ 2: AgentType.BICYCLE, ++ 3: AgentType.PEDESTRIAN, ++ 4: AgentType.UNKNOWN, ++ # ?: AgentType.STATIC, ++} ++ ++def create_splits_scenes(data_dir:str) -> Dict[str, List[str]]: ++ all_scenes = {} ++ all_scenes['train'] = [scene_path.split('/')[-1] for scene_path in glob.glob(data_dir+'/train/route*')] ++ all_scenes['val'] = [scene_path.split('/')[-1] for scene_path in glob.glob(data_dir+'/val/route*')] ++ return all_scenes ++ ++# TODO: (Yulong) format in object class ++def tracklet_to_pred(tracklet_mem,ego=False): ++ if ego: ++ x, y, z = np.split(tracklet_mem['location'],3,axis=-1) ++ hx, hy, hz = np.split(np.deg2rad(tracklet_mem['rotation']),3,axis=-1) ++ vx, vy, _ = np.split(tracklet_mem['velocity'],3,axis=-1) ++ ax, ay, _ = np.split(tracklet_mem['acceleration'],3,axis=-1) ++ else: ++ x, y, z = np.split(tracklet_mem['location'][0,:,-1,:],3,axis=-1) ++ hx, hy, hz = np.split(np.deg2rad(tracklet_mem['rotation'])[0,:,-1,:],3,axis=-1) ++ vx, vy, _ = np.split(tracklet_mem['velocity'][0,:,-1,:],3,axis=-1) ++ ax, ay, _ = np.split(tracklet_mem['acceleration'][0,:,-1,:],3,axis=-1) ++ pred_state = np.concatenate([ ++ x, -y, z, vx, -vy, ax, -ay, -hy ++ ],axis=-1) ++ return pred_state ++ ++def CarlaTracking(dataroot): ++ dataset_obj = defaultdict(lambda: defaultdict(dict)) ++ frames = list(dataroot.glob('*/route*/metadata/tracking/*.pkl')) ++ for frame in frames: ++ frame = str(frame) ++ with open(frame, 'rb') as f: ++ track_mem = pickle.load(f) ++ frame_idx = frame.split('/')[-1].split('.')[0] ++ scene = frame.split('/')[-4] ++ dataset_obj[scene]["all"][frame_idx] = track_mem ++ ++ frames = list(dataroot.glob('*/route*/metadata/ego/*.pkl')) ++ for frame in frames: ++ frame = str(frame) ++ with open(frame, 'rb') as f: ++ track_mem = pickle.load(f) ++ frame_idx = frame.split('/')[-1].split('.')[0] ++ scene = frame.split('/')[-4] ++ dataset_obj[scene]["ego"][frame_idx] = track_mem ++ ++ return dataset_obj ++ ++def agg_ego_data(all_frames): ++ agent_data = [] ++ agent_frame = [] ++ for frame_idx, frame_info in enumerate(all_frames): ++ pred_state = tracklet_to_pred(frame_info, ego=True) ++ agent_data.append(pred_state) ++ agent_frame.append(frame_idx) ++ ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ ("ego", idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name="ego", ++ agent_type=AgentType.VEHICLE, ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=all_frames[-1]["size"][1], width=all_frames[-1]["size"][0], height=all_frames[-1]["size"][2] ++ ), ++ ) ++ return Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ ) ++ ++def agg_agent_data(all_frames, agent_info, frame_idx): ++ agent_data = [] ++ agent_frame = [] ++ Agent_list = [] ++ for frame_idx, frame_info in enumerate(all_frames): ++ for idx in range(frame_info['id'].shape[1]): ++ if frame_info["id"][0,idx,0] == agent_info["id"]: ++ # two segments of tracking ++ if len(agent_frame) > 0 and frame_idx != agent_frame[-1] + 1: ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ (f'{int(agent_info["id"])}_{len(Agent_list)}', idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name=f'{int(agent_info["id"])}_{len(Agent_list)}', ++ agent_type=carla_to_trajdata_object_type[agent_info["cls"]], ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=agent_info["size"][1], width=agent_info["size"][0], height=agent_info["size"][2] ++ ), ++ ) ++ Agent_list.append(Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ )) ++ ++ agent_data = [] ++ agent_frame = [] ++ ++ ++ pred_state = tracklet_to_pred(frame_info)[idx] ++ agent_data.append(pred_state) ++ agent_frame.append(frame_idx) ++ ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ (str(int(agent_info["id"])), idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name=str(int(agent_info["id"])), ++ agent_type=carla_to_trajdata_object_type[agent_info["cls"]], ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=agent_info["size"][1], width=agent_info["size"][0], height=agent_info["size"][2] ++ ), ++ ) ++ ++ Agent_list.append(Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ )) ++ return Agent_list ++ ++ ++ ++class CarlaDataset(RawDataset): ++ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ++ # We're using the nuScenes prediction challenge split here. ++ # See https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/eval/prediction/splits.py ++ # for full details on how the splits are obtained below. ++ all_scene_splits: Dict[str, List[str]] = create_splits_scenes(data_dir) ++ ++ train_scenes: List[str] = deepcopy(all_scene_splits["train"]) ++ NUM_IN_TRAIN_VAL = round(len(train_scenes)*0.25) ++ all_scene_splits["train"] = train_scenes[NUM_IN_TRAIN_VAL:] ++ all_scene_splits["train_val"] = train_scenes[:NUM_IN_TRAIN_VAL] ++ ++ if env_name == 'carla': ++ carla_scene_splits: Dict[str, List[str]] = { ++ k: all_scene_splits[k] for k in ["train", "train_val", "val"] ++ } ++ ++ # nuScenes possibilities are the Cartesian product of these ++ dataset_parts: List[Tuple[str, ...]] = [ ++ ("train", "train_val", "val"), ++ ] ++ else: ++ raise ValueError(f"Unknown nuScenes environment name: {env_name}") ++ ++ self.scene_splits = carla_scene_splits ++ ++ # Inverting the dict from above, associating every scene with its data split. ++ carla_scene_split_map: Dict[str, str] = { ++ v_elem: k for k, v in carla_scene_splits.items() for v_elem in v ++ } ++ return EnvMetadata( ++ name=env_name, ++ data_dir=data_dir, ++ dt=nusc_utils.NUSC_DT, ++ parts=dataset_parts, ++ scene_split_map=carla_scene_split_map, ++ # The location names should match the map names used in ++ # the unified data cache. ++ map_locations=tuple([]), ++ ) ++ ++ def load_dataset_obj(self, verbose: bool = False) -> None: ++ if verbose: ++ print(f"Loading {self.name} dataset...", flush=True) ++ ++ self.dataset_obj = CarlaTracking( ++ dataroot=self.metadata.data_dir ++ ) ++ ++ def _get_matching_scenes_from_obj( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[SceneMetadata]: ++ all_scenes_list: List[CarlaSceneRecord] = list() ++ ++ scenes_list: List[SceneMetadata] = list() ++ for idx, scene_record in enumerate(self.dataset_obj): ++ scene_name: str = scene_record ++ scene_location: str = re.match('.*(Town\d+)', scene_record)[1] ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ scene_length: int = len(self.dataset_obj[scene_record]) ++ ++ # Saving all scene records for later caching. ++ all_scenes_list.append( ++ CarlaSceneRecord( ++ scene_name, scene_location, scene_length, idx ++ ) ++ ) ++ ++ if scene_split in scene_tag: ++ ++ scene_metadata = SceneMetadata( ++ env_name=self.metadata.name, ++ name=scene_name, ++ dt=self.metadata.dt, ++ raw_data_idx=idx, ++ ) ++ scenes_list.append(scene_metadata) ++ ++ self.cache_all_scenes_list(env_cache, all_scenes_list) ++ return scenes_list ++ ++ def _get_matching_scenes_from_cache( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[Scene]: ++ all_scenes_list: List[CarlaSceneRecord] = env_cache.load_env_scenes_list( ++ self.name ++ ) ++ ++ scenes_list: List[SceneMetadata] = list() ++ for scene_record in all_scenes_list: ++ ( ++ scene_name, ++ scene_location, ++ scene_length, ++ data_idx, ++ ) = scene_record ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ ++ if scene_split in scene_tag: ++ scene_metadata = Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ None, # This isn't used if everything is already cached. ++ ) ++ scenes_list.append(scene_metadata) ++ ++ return scenes_list ++ ++ def get_scene(self, scene_info: SceneMetadata) -> Scene: ++ _, route_name, _, data_idx = scene_info ++ ++ scene_record = sorted(self.dataset_obj[route_name]) ++ scene_name: str = route_name ++ scene_location: str = re.match('.*(Town\d+)', route_name)[1] ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ scene_length: int = len(self.dataset_obj[route_name]['ego']) ++ ++ return Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ scene_record, ++ ) ++ ++ def get_agent_info( ++ self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ++ ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: ++ ego_agent_info: AgentMetadata = AgentMetadata( ++ name="ego", ++ agent_type=AgentType.VEHICLE, ++ first_timestep=0, ++ last_timestep=scene.length_timesteps - 1, ++ extent=FixedExtent(length=4.084, width=1.730, height=1.562), #TODO: (yulong) replace with carla ego ++ ) ++ ++ agent_presence: List[List[AgentMetadata]] = [ ++ [ego_agent_info] for _ in range(scene.length_timesteps) ++ ] ++ ++ agent_data_list: List[pd.DataFrame] = list() ++ existing_agents: Dict[str, AgentMetadata] = dict() ++ ++ all_frames = [self.dataset_obj[scene.name]["all"][key] for key in sorted(self.dataset_obj[scene.name]["all"])] ++ ++ # frame_idx_dict = { ++ # frame_dict: idx for idx, frame_dict in enumerate(all_frames) ++ # } ++ for frame_idx, frame_info in enumerate(all_frames): ++ for idx in range(frame_info['id'].shape[1]): ++ if str(int(frame_info["id"][0,idx,0])) in existing_agents: ++ continue ++ ++ agent_info = {"id": frame_info["id"][0,idx,0], ++ "cls": frame_info["cls"][0,idx,:].argmax(), ++ "size": frame_info["size"][0,idx,:] } ++ # if not agent_info["next"]: ++ # # There are some agents with only a single detection to them, we don't care about these. ++ # continue ++ ++ agent_list: List[Agent] = agg_agent_data( ++ all_frames, agent_info, frame_idx ++ ) ++ for agent in agent_list: ++ for scene_ts in range( ++ agent.metadata.first_timestep, agent.metadata.last_timestep + 1 ++ ): ++ agent_presence[scene_ts].append(agent.metadata) ++ ++ existing_agents[agent.name] = agent.metadata ++ ++ agent_data_list.append(agent.data) ++ ++ ego_all_frames = [self.dataset_obj[scene.name]["ego"][key] for key in sorted(self.dataset_obj[scene.name]["ego"])] ++ ++ ego_agent: Agent = agg_ego_data(ego_all_frames) ++ agent_data_list.append(ego_agent.data) ++ ++ agent_list: List[AgentMetadata] = [ego_agent_info] + list( ++ existing_agents.values() ++ ) ++ ++ cache_class.save_agent_data(pd.concat(agent_data_list), cache_path, scene) ++ ++ return agent_list, agent_presence ++ ++ ++ def cache_maps( ++ self, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ ++ map_api = MapAPI(cache_path) ++ for carla_town in [f"Town0{x}" for x in range(1, 8)] + ["Town10", "Town10HD"]: # ["main"]: ++ vec_map = map_api.get_map( ++ f"drivesim:main" if carla_town == "main" else f"carla:{carla_town}", ++ incl_road_lanes=True, ++ incl_road_areas=True, ++ incl_ped_crosswalks=True, ++ incl_ped_walkways=True, ++ ) ++ map_cache_class.finalize_and_cache_map(cache_path, vec_map, map_params) +diff --git a/src/trajdata/dataset_specific/drivesim/__init__.py b/src/trajdata/dataset_specific/drivesim/__init__.py +new file mode 100644 +index 0000000..0167f39 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/__init__.py +@@ -0,0 +1 @@ ++from .drivesim_dataset import DrivesimDataset +diff --git a/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py b/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py +new file mode 100644 +index 0000000..865e451 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py +@@ -0,0 +1,116 @@ ++import warnings ++from copy import deepcopy ++from pathlib import Path ++from typing import Any, Dict, List, Optional, Tuple, Type, Union ++from collections import defaultdict ++ ++import pandas as pd ++ ++from tqdm import tqdm ++ ++from trajdata.caching import EnvCache, SceneCache ++from trajdata.data_structures.agent import ( ++ Agent, ++ AgentMetadata, ++ AgentType, ++ FixedExtent, ++ VariableExtent, ++) ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene, SceneMetadata ++from trajdata.data_structures.scene_tag import SceneTag ++from trajdata.dataset_specific.raw_dataset import RawDataset ++from trajdata.dataset_specific.scene_records import DrivesimSceneRecord ++from trajdata.maps import VectorMap ++ ++DRIVESIM_DT = 0.1 ++ ++class DrivesimDataset(RawDataset): ++ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ++ ++ return EnvMetadata( ++ name="drivesim", ++ data_dir=data_dir, ++ dt=DRIVESIM_DT, ++ parts=[("train",),("main",)], ++ scene_split_map=defaultdict(lambda: "train"), ++ # The location names should match the map names used in ++ # the unified data cache. ++ map_locations=("main",), ++ ) ++ ++ def load_dataset_obj(self, verbose: bool = False) -> None: ++ pass ++ ++ def _get_matching_scenes_from_obj( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[SceneMetadata]: ++ raise NotImplementedError() ++ ++ def _get_matching_scenes_from_cache( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[Scene]: ++ all_scenes_list: List[DrivesimSceneRecord] = env_cache.load_env_scenes_list( ++ self.name ++ ) ++ scenes_list: List[SceneMetadata] = list() ++ for scene_record in all_scenes_list: ++ ( ++ scene_name, ++ scene_location, ++ scene_length, ++ scene_desc, ++ data_idx, ++ ) = scene_record ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ ++ if scene_location.split("-")[0] in scene_tag and scene_split in scene_tag: ++ if scene_desc_contains is not None and not any( ++ desc_query in scene_desc for desc_query in scene_desc_contains ++ ): ++ continue ++ ++ scene_metadata = Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ None, # This isn't used if everything is already cached. ++ scene_desc, ++ ) ++ scenes_list.append(scene_metadata) ++ ++ return scenes_list ++ ++ def get_scene(self, scene_info: SceneMetadata) -> Scene: ++ raise NotImplementedError() ++ ++ def get_agent_info( ++ self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ++ ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: ++ raise NotImplementedError() ++ ++ def cache_map( ++ self, ++ map_name: str, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ raise NotImplementedError() ++ ++ def cache_maps( ++ self, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ raise NotImplementedError() +diff --git a/src/trajdata/dataset_specific/drivesim/nvmap_utils.py b/src/trajdata/dataset_specific/drivesim/nvmap_utils.py +new file mode 100644 +index 0000000..0f41cc2 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/nvmap_utils.py +@@ -0,0 +1,1792 @@ ++ ++import os, shutil, copy ++import glob ++import json ++from collections import OrderedDict, deque ++ ++import numpy as np ++import cv2 ++ ++import matplotlib as mpl ++# mpl.use('Agg') ++import matplotlib.pyplot as plt ++from matplotlib.patches import Circle, Polygon, PathPatch, Ellipse ++from matplotlib.path import Path ++ ++import sys ++ ++# NOTE: NVMap documentation https://nvmapspec.drive.nvda.io/map_layer_format/index.html ++ ++# transformation matrix of the base pose of the global cooridinate system (defined in ECEF) ++# will be used to transform each road segment to global frame. ++# rivermark is used for autolabels ++GLOBAL_BASE_POSE_RIVERMARK = np.array([[ 0.7661314567952979, -0.47780431092164843, -0.42982046411659275, -2684537.8381108316], ++ [ 0.06759529959471867, 0.7249870917843614, -0.6854375188292178, -4305029.705512633], ++ [ 0.6391192896333319, 0.49608140179888094, 0.5877327423309361, 3852303.4349504006], ++ [ 0.0, 0.0, 0.0, 1.0]]) ++# endeavor can be used for map if desired ++GLOBAL_BASE_POSE_ENDEAVOR = np.array([[0.892451945618358, 0.11765375230214345, -0.43553084773783024, -2686874.9549495797], ++ [-0.40233607907375113, 0.644307257120398, -0.6503797643665964, -4305568.426410184], ++ [0.2040960661981692, 0.7556624596942837, 0.6223496145826857, 3850100.0530229895], ++ [ 0.0, 0.0, 0.0, 1.0]]) ++# this is the one that will be used for the map and trajectories ++GLOBAL_BASE_POSE = GLOBAL_BASE_POSE_ENDEAVOR ++# GLOBAL_BASE_POSE = GLOBAL_BASE_POSE_RIVERMARK ++ ++MAX_ENDEAVOR_FRAME = 3613 ++ ++RIVERMARK_VIZ = False ++# 0 : drivable only ++# 1 : + ego ++# 2 : + lines ++# 3 : + other non-extra cars ++# 4 : + all cars ++RIVERMARK_VIZ_LAYER = 4 ++NVCOLOR = (118.0/255, 185.0/255, 0.0/255) ++NVCOLOR2 = (55.0/255, 225.0/255, 0.0/255) ++# EXTRA_DYN_LWH = [5.087, 2.307, 1.856] ++EXTRA_DYN_LWH = [4.387, 1.907, 1.656] ++ ++NUSC_EXPAND_LEN = 0.0 #2.0 # meters to expand drivable area outward ++ ++AUTOLABEL_DT = 0.1 # sec ++ ++NVMAP_LAYERS = {'lane_dividers_v1', 'lanes_v1', 'lane_channels_v1'} ++SEGMENT_GRAPH_LAYER = 'segment_graph_v1' ++ ++# maps track ids (which occur first in time) to other tracks that are ++# actually the same object ++# NOTE: each association list must be sorted in temporal order ++TRACK_ASSOC = { ++ # frame 3000 - 3545 sequence ++ 2486 : [2584, 2669], ++ 2546 : [2618], ++ 2489 : [2558], ++ 2496 : [2553], ++ 2515 : [2578], ++ 2525 : [2615], ++ 2555 : [2609], ++ 2546 : [2618, 2673], ++ # 2491 : [2566, 2631], ++ 2491 : [2631], ++ # 2566 : [2631], ++ 2484 : [2692, 2729, 2745, 2766], ++ 2686 : [2811], ++ 2618 : [2673], ++ 2592 : [2719], ++ 2251 : [2847], ++ 2269 : [2734, 2853], ++ 1280 : [2768, 2790], ++ 2631 : [2685], ++ 2561 : [2699], ++ # frame 570 - 990 sequences ++ 1298 : [1360], ++ 1282 : [1362], ++ 1241 : [1366], ++ 1305 : [1398], ++ 1286 : [1377], ++ 1327 : [1409], ++ 1444 : [1576], ++ 1326 : [1455], ++ 1376 : [1425, 1435, 1444], ++ 1345 : [1445, 1456, 1476], ++ 1346 : [1477, 1491], ++ # frame 270 - 600 ++ 1067 : [1193], ++ 1109 : [1206], ++ 1126 : [1229], ++ 1127 : [1234], ++ 1137 : [1255], ++ 1150 : [1251], ++ 1168 : [1246], ++ 1184 : [1269], ++ 1203 : [1225], ++ 1196 : [1307], ++ 1211 : [1315], ++ 1191 : [1297], ++ 1183 : [1276], ++ 1188 : [1274], ++ 1178 : [1250], ++ # frame 1440 - 1860 ++ 1717 : [1761], ++ 1712 : [1778], ++ 1704 : [1786], ++ 1733 : [1803], ++ 1708 : [1722,1811], ++ 1674 : [1732,1740], ++ 1760 : [1837], ++ 1824 : [1836], ++ 1793 : [1872], ++ 1812 : [1869], ++ 1748 : [1812], ++ 1826 : [1864], ++ 1723 : [1756], ++ # rivermark ++ 9409 : [9517] ++} ++# false positive (or extrapolation is bad) ++# 2566 ++TRACK_REMOVE = {2676, 2687, 2590, 2833, 1199, 9413, 9449, 9437, 9589, 2588, 2669, 2562, 2516, 2566} # 9449 is rivermark dynamic, and 9437 is rivermark trash can ++ ++def check_single_veh_coll(traj_tgt, lw_tgt, traj_others, lw_others): ++ ''' ++ Checks if the target trajectory collides with each of the given other trajectories. ++ ++ Assumes all trajectories and attributes are UNNORMALIZED. Handles nan frames in traj_others by simply skipping. ++ ++ :param traj_tgt: (T x 4) ++ :param lw_tgt: (2, ) ++ :param traj_others: (N x T x 4) ++ :param lw_others: (N x 2) ++ ++ :returns veh_coll: (N) ++ :returns coll_time: (N) ++ ''' ++ from shapely.geometry import Polygon ++ ++ NA, FT, _ = traj_others.shape ++ ++ veh_coll = np.zeros((NA, FT), dtype=np.bool) ++ poly_cache = dict() # for the tgt polygons since used many times ++ for aj in range(NA): ++ for t in range(FT): ++ # compute iou ++ if t not in poly_cache: ++ ai_state = traj_tgt[t, :] ++ if np.sum(np.isnan(ai_state)) > 0: ++ continue ++ ai_corners = get_corners(ai_state, lw_tgt) ++ ai_poly = Polygon(ai_corners) ++ poly_cache[t] = ai_poly ++ else: ++ ai_poly = poly_cache[t] ++ ++ aj_state = traj_others[aj, t, :] ++ if np.sum(np.isnan(aj_state)) > 0: ++ continue ++ aj_corners = get_corners(aj_state, lw_others[aj]) ++ aj_poly = Polygon(aj_corners) ++ cur_iou = ai_poly.intersection(aj_poly).area / ai_poly.union(aj_poly).area ++ if cur_iou > 0.02: ++ veh_coll[aj, t] = True ++ ++ return veh_coll ++ ++def plt_color(i): ++ clist = plt.rcParams['axes.prop_cycle'].by_key()['color'] ++ return clist[i] ++ ++def get_rot(h): ++ return np.array([ ++ [np.cos(h), np.sin(h)], ++ [-np.sin(h), np.cos(h)], ++ ]) ++ ++def get_corners(box, lw): ++ l, w = lw ++ simple_box = np.array([ ++ [-l/2., -w/2.], ++ [l/2., -w/2.], ++ [l/2., w/2.], ++ [-l/2., w/2.], ++ ]) ++ h = np.arctan2(box[3], box[2]) ++ rot = get_rot(h) ++ simple_box = np.dot(simple_box, rot) ++ simple_box += box[:2] ++ return simple_box ++ ++def plot_box(box, lw, color='g', alpha=0.7, no_heading=False): ++ l, w = lw ++ h = np.arctan2(box[3], box[2]) ++ simple_box = get_corners(box, lw) ++ ++ arrow = np.array([ ++ box[:2], ++ box[:2] + l/2.*np.array([np.cos(h), np.sin(h)]), ++ ]) ++ ++ plt.fill(simple_box[:, 0], simple_box[:, 1], color=color, edgecolor='k', alpha=alpha, linewidth=1.0, zorder=3) ++ if not no_heading: ++ # plt.plot(arrow[:, 0], arrow[:, 1], color, alpha=1.0) ++ plt.plot(arrow[:, 0], arrow[:, 1], 'k', alpha=alpha, zorder=3) ++ ++def create_video(img_path_form, out_path, fps): ++ ''' ++ Creates a video from a format for frame e.g. 'data_out/frame%04d.png'. ++ Saves in out_path. ++ ''' ++ import subprocess ++ # if RIVERMARK_VIZ: ++ # ffmpeg_cmd = ['ffmpeg', '-y', '-i', img_path_form, ++ # '-vf', 'transpose=2', img_path_form] ++ # subprocess.run(ffmpeg_cmd) ++ ++ ffmpeg_cmd = ['ffmpeg', '-y', '-r', str(fps), '-i', img_path_form, ++ '-vcodec', 'libx264', '-crf', '18', '-pix_fmt', 'yuv420p', out_path] ++ subprocess.run(ffmpeg_cmd) ++ ++def debug_viz_vid(seg_dict, prefix, poses, poses_valid, poses_lwh, ++ comp_out_path='./out/dev_nvmap', ++ fps=10, ++ subsamp=3, ++ pose_ids=None, ++ **kwargs): ++ poses = poses[:,::subsamp] ++ poses_valid = poses_valid[:,::subsamp] ++ T = poses.shape[1] ++ out_dir = os.path.join(comp_out_path, prefix) ++ if not os.path.exists(out_dir): ++ os.makedirs(out_dir) ++ for t in range(T): ++ print('rendering frame %d...' % (t)) ++ debug_viz_segs(seg_dict, 'frame_%06d' % (t), poses[:,t:t+1], poses_valid[:,t:t+1], poses_lwh, ++ comp_out_path=out_dir, ++ pose_ids=pose_ids, ++ ego_traj=poses[0], ++ **kwargs) ++ create_video(os.path.join(out_dir, 'frame_%06d.jpg'), out_dir + '.mp4', fps) ++ ++def debug_viz_segs(seg_dict, prefix, poses=None, poses_valid=None, poses_lwh=None, pose_ids=None, ++ comp_out_path='./out/dev_nvmap', ++ extent=80, ++ grid=True, ++ show_ticks=True, ++ dpi=100, ++ ego_traj=None): ++ ''' ++ Visualize segments in the given dictionary. ++ ++ :param seg_dict: segments dictionary ++ :param prefix: prefix to save figure ++ :param poses: NxTx4x4 trajectory for N vehicles that will be plotted if given. ++ ''' ++ if not os.path.exists(comp_out_path): ++ os.makedirs(comp_out_path) ++ ++ # fig = plt.figure() ++ fig = plt.figure(dpi=dpi) ++ ++ origins = [] ++ arr_len = 10.0 ++ for _, seg in seg_dict.items(): ++ if poses is not None: ++ dist2ego = np.linalg.norm(seg.local2world[:2, -1] - poses[0,0,:2,-1]) ++ if dist2ego > 200: ++ continue ++ # plot layers ++ if 'drivable_area' in seg.layers: ++ # draw drivable area first so under everything else ++ for drivable_poly in seg.layers['drivable_area']: ++ polypatch = Polygon(drivable_poly[:,:2], ++ color='darkgray', ++ alpha=1.0, ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(polypatch) ++ ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 2: ++ # only drivable area ++ continue ++ ++ for layer_k, layer_v in seg.layers.items(): ++ if layer_k in {'lane_dividers', 'lane_divider'}: ++ for lane_div in layer_v: ++ polyline = lane_div.polyline if isinstance(lane_div, LaneDivider) else lane_div ++ linepatch = PathPatch(Path(polyline[:,:2]), ++ fill=False, ++ color='gold', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ elif layer_k in {'road_dividers', 'road_divider'}: ++ for road_div in layer_v: ++ linepatch = PathPatch(Path(road_div[:,:2]), ++ fill=False, ++ color='orange', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ elif layer_k in {'road_boundaries'}: ++ for road_bound in layer_v: ++ linepatch = PathPatch(Path(road_bound.polyline[:,:2]), ++ fill=False, ++ color='darkgray', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ # elif layer_k in {'lane_channels'}: ++ # for lane_channel in layer_v: ++ # linepatch = PathPatch(Path(lane_channel.left), ++ # fill=False, ++ # color='blue', ++ # linestyle='-') ++ # # linewidth=2) ++ # plt.gca().add_patch(linepatch) ++ # linepatch = PathPatch(Path(lane_channel.right), ++ # fill=False, ++ # color='red', ++ # linestyle='-') ++ # # linewidth=2) ++ # plt.gca().add_patch(linepatch) ++ ++ # plot local coordinate system origin ++ if poses is None: ++ local_coords = np.array([[0.0, 0.0, 0.0, 1.0], [arr_len, 0.0, 0.0, 1.0], [0.0, arr_len, 0.0, 1.0]]) ++ world_coords = np.dot(seg.local2world, local_coords.T).T ++ world_coords = world_coords[:,:2] # only plot 2D coords ++ origins.append(world_coords[0]) ++ xdelta = world_coords[1] - world_coords[0] ++ ydelta = world_coords[2] - world_coords[0] ++ plt.arrow(world_coords[0, 0], world_coords[0, 1], xdelta[0], xdelta[1], color='red') ++ plt.arrow(world_coords[0, 0], world_coords[0, 1], ydelta[0], ydelta[1], color='green') ++ ++ if poses is not None and poses_valid is not None and poses_lwh is not None: ++ # center on ego ++ origin = poses[0,0,:2,-1] ++ if RIVERMARK_VIZ: ++ extent = 45 ++ origin = origin + np.array([extent - 15.0, 0.0]) ++ ++ # if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER > 3: ++ # # plot ego traj ++ # ego_pos = ego_traj[::10,:2,3] ++ # plt.plot(ego_pos[:,0], ego_pos[:,1], 'o-', c=NVCOLOR2, markersize=2.5) #, markersize=8), linewidth ++ ++ plt.xlim(origin[0]-extent, origin[0]+extent) ++ plt.ylim(origin[1]-extent, origin[1]+extent) ++ for n in range(poses.shape[0]): ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 1: ++ continue ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 3 and n != 0: ++ continue ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 4 and pose_ids[n] == 'extra': ++ continue ++ # if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER == 3 and n == 0: ++ # continue ++ cur_color = plt_color((n+2) % 9) ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER >= 4 and pose_ids[n] == 'extra': ++ cur_color = '#ff00ff' ++ if RIVERMARK_VIZ: ++ print(n) ++ print(pose_ids[n]) ++ print(cur_color) ++ cur_poses = poses[n] #, ::20] ++ xy = cur_poses[:,:2,3] ++ hvec = cur_poses[:,:2,0] # x axis ++ hvec = hvec / np.linalg.norm(hvec, axis=1, keepdims=True) ++ for t in range(cur_poses.shape[0]): ++ if poses_valid[n, t]: ++ plot_box(np.array([xy[t,0], xy[t,1], hvec[t,0], hvec[t,1]]), poses_lwh[n,:2], ++ color=NVCOLOR if n ==0 else cur_color, alpha=1.0, no_heading=False) ++ if pose_ids is not None and not RIVERMARK_VIZ: ++ plt.text(xy[t,0] + 1.0, xy[t,1] + 1.0, pose_ids[n], c='red', fontsize='x-small') ++ ++ plt.gca().set_aspect('equal') ++ plt.grid(grid) ++ if not show_ticks: ++ plt.xticks([]) ++ plt.yticks([]) ++ plt.gca().axis('off') ++ # plt.tight_layout() ++ cur_save_path = os.path.join(comp_out_path, prefix + '.jpg') ++ fig.savefig(cur_save_path) ++ # plt.show() ++ plt.close(fig) ++ ++ if RIVERMARK_VIZ: ++ og_img = cv2.imread(cur_save_path) ++ rot_img = cv2.rotate(og_img, cv2.cv2.ROTATE_90_COUNTERCLOCKWISE) ++ cv2.imwrite(cur_save_path, rot_img) ++ ++ ++def get_tile_mask(tile, layer_name, local_box, canvas_size): ++ ''' ++ Rasterizes a layer of the given tile object into a binary mask. ++ Assumes tile object has been converted to hold nuscenes-like layers. ++ ++ :param tile: NVMapTile object holding the nuscenes layers ++ :param layer_name str: which layer to rasterize, currently supports ['drivable_area', 'carpark_area', 'road_divider', 'lane_divider'] ++ :param local_box tuple: (center_x, center_y, H, W) in meters which patch of the map to rasterize ++ :param canvas_size tuple: (H, W) pixels tuple which determines the resolution at which the layer is rasterized ++ ''' ++ # must transform each map element to pixel space ++ # https://github.com/nutonomy/nuscenes-devkit/blob/9b209638ef3dee6d0cdc5ac700c493747f5b35fe/python-sdk/nuscenes/map_expansion/map_api.py#L1894 ++ patch_x, patch_y, patch_h, patch_w = local_box ++ canvas_h = canvas_size[0] ++ canvas_w = canvas_size[1] ++ scale_height = canvas_h/patch_h ++ scale_width = canvas_w/patch_w ++ trans_x = -patch_x + patch_w / 2.0 ++ trans_y = -patch_y + patch_h / 2.0 ++ trans = np.array([[trans_x, trans_y]]) ++ scale = np.array([[scale_width, scale_height]]) ++ ++ map_mask = np.zeros(canvas_size, np.uint8) ++ for seg_id, seg in tile.segments.items(): ++ for poly_pts in seg.layers[layer_name]: ++ # convert to pixel coords ++ poly_pts = (poly_pts + trans)*scale ++ # rasterize ++ if layer_name in {'drivable_area'}: ++ # polygon ++ coords = poly_pts.astype(np.int32) ++ cv2.fillPoly(map_mask, [coords], 1) ++ elif layer_name in {'lane_divider', 'road_divider'}: ++ # polyline ++ coords = poly_pts.astype(np.int32) ++ cv2.polylines(map_mask, [coords], False, 1, 2) ++ elif layer_name in {'carpark_area'}: ++ # empty ++ pass ++ else: ++ print('Unrecognized layer %d - cannot render mask' % (layer_name)) ++ ++ return map_mask ++ ++# https://www.dmv.ca.gov/portal/handbook/california-driver-handbook/lane-control/ ++# - solid yellow lines = center of road for two-way (road divider) ++# - two, one solid one broken yellow = may pass but going opposite dir (road divider) ++# - two solid yellow = road divider ++# - solid white = edge of road going same way (lane divider) ++# - broken white = two or more lanes same direction (lane divider) ++# - double white = HOV (lane divider) ++# - invisible should be ommitted (e.g. will just be in intersections) ++ ++def convert_tile_to_nuscenes(tile): ++ ''' ++ Given a tile, converts its layers into similar format as nuscenes. ++ This includes converting lane dividers and road boundaries to lane/road dividers and ++ drivable area. ++ returns an updated copy of the tile. ++ ''' ++ print('Converting to nuscenes...') ++ tile = copy.deepcopy(tile) ++ for seg_id, seg in tile.segments.items(): ++ if 'lane_dividers' in seg.layers: ++ nusc_lane_dividers = [] ++ nusc_road_dividers = [] ++ for div in seg.layers['lane_dividers']: ++ style = div.style ++ if len(style) > 0: ++ div_color = style[0][2] ++ div_pattern = style[0][0] ++ if div_color in {'White'}: #, 'Green'}: ++ # divides traffic in same direction ++ nusc_lane_dividers.append(div.polyline[:,:2]) ++ elif div_color in {'Yellow'} or (len(style) == 2 and div_pattern == 'Botts Dots'): #, 'Blue', 'Red', 'Orange'}: ++ # divides traffic in opposite direction ++ nusc_road_dividers.append(div.polyline[:,:2]) ++ # elif RIVERMARK_VIZ: ++ # nusc_lane_dividers.append(div.polyline[:,:2]) ++ # update segment ++ seg.layers['lane_divider'] = nusc_lane_dividers ++ seg.layers['road_divider'] = nusc_road_dividers ++ ++ del seg.layers['lane_dividers'] ++ ++ if 'road_boundaries' in seg.layers: ++ del seg.layers['road_boundaries'] ++ pass # actually don't need this for now, it's just for reference ++ ++ if 'lane_channels' in seg.layers: ++ # convert to drivable area polygon ++ expand_len = NUSC_EXPAND_LEN # meters ++ drivable_area_polys = [] ++ for channel in seg.layers['lane_channels']: ++ left = channel.left[:,:2] ++ right = channel.right[:,:2] ++ if RIVERMARK_VIZ: ++ seg.layers['lane_divider'].append(left) ++ seg.layers['road_divider'].append(right) ++ if left.shape[0] > 1: ++ # compute normals at each vertex to expand ++ left_diff = left[1:] - left[:-1] ++ left_diff = np.concatenate([left_diff, left_diff[-1:,:]], axis=0) ++ left_diff = left_diff / np.linalg.norm(left_diff, axis=1, keepdims=True) ++ left_norm = np.concatenate([-left_diff[:,1:2], left_diff[:,0:1]], axis=1) ++ right_diff = right[1:] - right[:-1] ++ right_diff = np.concatenate([right_diff, right_diff[-1:,:]], axis=0) ++ right_diff = right_diff / np.linalg.norm(right_diff, axis=1, keepdims=True) ++ right_norm = np.concatenate([right_diff[:,1:2], -right_diff[:,0:1]], axis=1) ++ # expand channel ++ left = left + (left_norm * expand_len) ++ right = right + (right_norm * expand_len) ++ ++ channel_poly = np.concatenate([right, np.flip(left, axis=0)], axis=0) ++ drivable_area_polys.append(channel_poly) ++ seg.layers['drivable_area'] = drivable_area_polys ++ del seg.layers['lane_channels'] ++ ++ # add empty carpark area for completeness ++ seg.layers['carpark_area'] = [] ++ ++ # compute extents of map ++ map_maxes = np.array([-float('inf'), -float('inf')]) # xmax, ymax ++ map_mins = np.array([float('inf'), float('inf')]) # xmin, ymin ++ for _, seg in tile.segments.items(): ++ for k, v in seg.layers.items(): ++ if len(v) > 0: ++ all_pts = np.concatenate(v, axis=0) ++ cur_maxes = np.amax(all_pts, axis=0) ++ cur_mins = np.amin(all_pts, axis=0) ++ map_maxes = np.where(cur_maxes > map_maxes, cur_maxes, map_maxes) ++ map_mins = np.where(cur_mins < map_mins, cur_mins, map_mins) ++ map_xlim = (map_mins[0] - 10, map_maxes[0] + 10) # buffer of 10m ++ map_ylim = (map_mins[1] - 10, map_maxes[1] + 10) # buffer of 10m ++ W = map_xlim[1] - map_xlim[0] ++ H = map_ylim[1] - map_ylim[0] ++ # translate so bottom left corner is at origin ++ trans_offset = np.array([[-map_xlim[0], -map_ylim[0]]]) ++ tile.trans_offset = trans_offset ++ tile.H = H ++ tile.W = W ++ for _, seg in tile.segments.items(): ++ seg.local2world[:2, -1] += trans_offset[0] ++ for k, v in seg.layers.items(): ++ if len(v) > 0: ++ for pts in v: ++ pts += trans_offset ++ ++ return tile ++ ++def load_tile(tile_path, ++ layers=['lane_dividers_v1', 'lane_channels_v1']): ++ # load in all road segment dicts ++ print('Parsing segment graph...') ++ tile_name = tile_path.split('/')[-1] ++ segs_path = os.path.join(tile_path, SEGMENT_GRAPH_LAYER) ++ assert os.path.exists(tile_path), 'cannot find segment graph layer, which is required to load any layers' ++ seg_json_list, _ = load_json_dir(segs_path) ++ road_segments = parse_road_segments(seg_json_list) ++ print('Found %d road segments:' % (len(road_segments))) ++ ++ # fill road segments with other desired layers ++ print('Loading requested layers...') ++ for layer in layers: ++ assert layer in NVMAP_LAYERS, 'loading layer type %s is currently not supported!' % (layer) ++ layer_dir = os.path.join(tile_path, layer) ++ assert os.path.exists(layer_dir), 'could not find requested layer %s in tile directory!' % (layer) ++ layer_json_list, seg_names = load_json_dir(layer_dir) ++ if layer == 'lane_dividers_v1': ++ parse_lane_dividers(layer_json_list, seg_names, road_segments) ++ elif layer == 'lane_channels_v1': ++ parse_lane_channels(layer_json_list, seg_names, road_segments) ++ elif layer == 'lanes_v1': ++ raise NotImplementedError() ++ ++ return NVMapTile(road_segments, name=tile_name) ++ ++ ++def load_json_dir(json_dir_path): ++ ''' ++ Loads in all json files in the given directory and returns the resulting list of dicts along ++ with the names of the json files read from. ++ ''' ++ json_files = sorted(glob.glob(os.path.join(json_dir_path, '*.json'))) ++ file_names = ['.'.join(jf.split('/')[-1].split('.')[:-1]) for jf in json_files] ++ json_list = [] ++ for jf in json_files: ++ with open(jf, 'r') as f: ++ json_list.append(json.load(f)) ++ return json_list, file_names ++ ++def parse_pt(pt_dict): ++ pt_entries = ['x', 'y', 'z', 'w'] ++ if len(pt_dict) == 3: ++ pt = [pt_dict[pt_entries[0]], pt_dict[pt_entries[1]], pt_dict[pt_entries[2]]] ++ elif len(pt_dict) == 4: ++ pt = [pt_dict[pt_entries[0]], pt_dict[pt_entries[1]], pt_dict[pt_entries[2]], pt_dict[pt_entries[3]]] ++ else: ++ assert False, 'input point must be length 3 or 4' ++ return pt ++ ++def parse_pt_list(pt_list): ++ return np.array([parse_pt(pt) for pt in pt_list]) ++ ++def parse_lane_channels(lane_channels_json_list, seg_name_list, road_segments): ++ ''' ++ Parses lane channels in each segment to an object, and store in its respective segment. ++ Transforms channels into the global coordinate system. ++ ++ :param lane_channels_json_list: list of lane channels json dicts ++ :param seg_name_list: the name of the segment corresponding to each json file in lane_div_json_list ++ :param road_segments: dict of all road segments in a tile ++ :return: updated road_segments (also updated in place) ++ ''' ++ for channel_dict, seg_name in zip(lane_channels_json_list, seg_name_list): ++ cur_seg = road_segments[seg_name] ++ # load lane dividers ++ lane_channels = [] ++ lane_channel_dicts = channel_dict['channels'] ++ for channel in lane_channel_dicts: ++ # left ++ left_geom = channel['left_side']['geometry'][0]['chunk']['channel_edge_line'] ++ left_polyline = parse_pt_list(left_geom['points']) ++ left_polyline = cur_seg.to_global(left_polyline)[:,:3] ++ # right ++ right_geom = channel['right_side']['geometry'][0]['chunk']['channel_edge_line'] ++ right_polyline = parse_pt_list(right_geom['points']) ++ right_polyline = cur_seg.to_global(right_polyline)[:,:3] ++ assert left_polyline.shape[0] == right_polyline.shape[0], 'channel edges should be same length!' ++ # build lane channel object ++ lane_channels.append(LaneChannel(left_polyline, right_polyline)) ++ cur_seg.layers['lane_channels'] = lane_channels ++ ++ return road_segments ++ ++def parse_lane_dividers(lane_div_json_list, seg_name_list, road_segments): ++ ''' ++ Parses lane dividers in each segment to an object, and store in its respective segment. ++ Transforms dividers into the global coordinate system. ++ ++ :param lane_div_json_list: list of lane divider json dicts ++ :param seg_name_list: the name of the segment corresponding to each json file in lane_div_json_list ++ :param road_segments: dict of all road segments in a tile ++ :return: updated road_segments (also updated in place) ++ ''' ++ for div_dict, seg_name in zip(lane_div_json_list, seg_name_list): ++ cur_seg = road_segments[seg_name] ++ # load lane dividers ++ lane_divs = [] ++ lane_div_dicts = div_dict['dividers'] ++ for div in lane_div_dicts: ++ # NOTE: left/right/height lines also sometimes available - ignoring for now ++ # NOTE: lane divider styles (type/color) also available - ignoring for now ++ # center_line is only guaranteed ++ lane_geom = div['geometry']['center_line'] ++ lane_polyline = parse_pt_list(lane_geom['points']) ++ lane_polyline = cur_seg.to_global(lane_polyline)[:,:3] ++ # parse divider style ++ if 'style' in div: ++ lane_styles = div['style'] ++ lane_styles = [(style['pattern'], style['material'], style['color']) for style in lane_styles] ++ else: ++ lane_styles = [] ++ # build lane div object ++ lane_divs.append(LaneDivider(lane_polyline, lane_styles)) ++ cur_seg.layers['lane_dividers'] = lane_divs ++ ++ # load road boundaries ++ road_bounds = [] ++ road_bound_dicts = div_dict['road_boundaries'] ++ for div in road_bound_dicts: ++ # NOTE: left/right/height lines also sometimes available - ignoring for now ++ # NOTE: road boundary type also available - ignoring for now ++ # center_line is only guaranteed ++ bound_geom = div['geometry']['center_line'] ++ bound_polyline = parse_pt_list(bound_geom['points']) ++ bound_polyline = cur_seg.to_global(bound_polyline)[:,:3] ++ # parse boundary type ++ if 'type' in div: ++ bound_type = div['type'] ++ else: ++ bound_type = None ++ # build object ++ road_bounds.append(RoadBoundary(bound_polyline, bound_type)) ++ cur_seg.layers['road_boundaries'] = road_bounds ++ ++ return road_segments ++ ++def parse_road_segments(seg_json_list): ++ ''' ++ Parses road segments into objects, and transforms into a shared coordinate system. ++ ++ :param seg_json_list: list of road segment json dicts ++ ++ :return: dict of all road segments mapping id -> RoadSegment ++ ''' ++ # build objects for all road segments ++ road_segments = OrderedDict() ++ for seg_dict in seg_json_list: ++ cur_seg = build_road_seg(seg_dict) ++ road_segments[cur_seg.id] = cur_seg ++ ++ # go through and annotate local2world by converting GPS to the "global" coordinate system ++ for seg_id, seg in road_segments.items(): ++ lat_lng_alt = np.array(seg.gps_origin).reshape((-1,3)) ++ rot_axis = np.array([1.0, 0.0, 0.0]).reshape((-1,3)) ++ rot_angle = np.array([0.0]).reshape((-1,1)) ++ ecef_pose = lat_lng_alt_2_ecef(lat_lng_alt, rot_axis, rot_angle, 'WGS84')[0] ++ seg.local2world = np.linalg.inv(GLOBAL_BASE_POSE) @ ecef_pose ++ ++ return road_segments ++ ++def build_road_seg(seg_dict): ++ ''' ++ Parses road segment json dictionary into an object. ++ ''' ++ segment = seg_dict['segment'] ++ seg_id = segment['id'] ++ seg_origin = [segment['origin']['lat'], segment['origin']['lon'], segment['origin']['height']] ++ connections = [] ++ if 'connections' in segment: ++ conn_list = segment['connections'] ++ for conn in conn_list: ++ source_id = conn['source_id'] ++ source2tgt = [] ++ for ci in range(4): ++ source2tgt.append(np.array(parse_pt(conn['source_to_target'][f'column_{ci}']))) ++ source2tgt = np.stack(source2tgt, axis=1) ++ connections.append((source_id, source2tgt)) ++ return RoadSegment(seg_id, seg_origin, connections) ++ ++def collect_seg_origins(road_segments): ++ ''' ++ Returns np array of all road_segment origins (in order of dict). ++ :param road_segments: OrderedDict of RoadSegment objects ++ ''' ++ return np.array([seg.origin for _, seg in road_segments.items()]) ++ ++class RoadSegment(object): ++ def __init__(self, seg_id, origin, connections, ++ is_root=False, ++ layers=None): ++ ''' ++ :param seg_id str: ++ :param origin: list of [lat, lon, height] ++ :param connections list: list of tuples, each containing (neighbor_id, neighbor2local_transform) ++ where the transform is a np.array(4,4)) ++ :param layers dict: layer objects within this road segment ++ ''' ++ self.id = seg_id ++ self.gps_origin = origin # GPS ++ self.connections = connections ++ self.layers = layers if layers is not None else dict() ++ # transformation matrix from local to world frame ++ self.local2world = None ++ ++ def to_global(self, pts): ++ ''' ++ Transform an array of points from this segment's frame to global. ++ ++ :param pts: np array (N x 3) ++ ''' ++ pts = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1) ++ pts = np.dot(self.local2world, pts.T).T ++ return pts ++ ++ def transform(self, mat): ++ ''' ++ Transforms this segment and all contained layers by the given 4x4 transformation matrix. ++ ''' ++ self.local2world = mat @ self.local2world ++ for _, layer in self.layers.items(): ++ for el in layer: ++ el.transform(mat) ++ ++ def __repr__(self): ++ return '' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) ++ ++class LaneDivider(object): ++ def __init__(self, polyline, style): ++ ''' ++ :param polyline: np.array Nx3 defining the divider geometry ++ :param style: list of tuples of (pattern, matterial, color) ++ ''' ++ self.polyline = polyline ++ self.style = style ++ ++ def transform(self, mat): ++ ''' ++ Transforms this map element by the given 4x4 transformation matrix. ++ ''' ++ pts = np.concatenate([self.polyline, np.ones((self.polyline.shape[0], 1))], axis=1) ++ pts = np.dot(mat, pts.T).T ++ self.polyline = pts[:,:3] ++ ++ def __repr__(self): ++ return '' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) ++ ++class RoadBoundary(object): ++ def __init__(self, polyline, bound_type): ++ ''' ++ :param polyline: np.array Nx3 defining the divider geometry ++ :param type str: type of the boundary ++ ''' ++ self.polyline = polyline ++ self.bound_type = bound_type ++ ++ def transform(self, mat): ++ ''' ++ Transforms this map element by the given 4x4 transformation matrix. ++ ''' ++ pts = np.concatenate([self.polyline, np.ones((self.polyline.shape[0], 1))], axis=1) ++ pts = np.dot(mat, pts.T).T ++ self.polyline = pts[:,:3] ++ ++ def __repr__(self): ++ return '' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) ++ ++class LaneChannel(object): ++ def __init__(self, left_polyline, right_polyline): ++ ''' ++ :param left_polyline: np.array Nx3 defining the channel left edge geometry ++ :param right_polyline: np.array Nx3 defining the channel right edge geometry ++ ''' ++ self.left = left_polyline ++ self.right = right_polyline ++ ++ def transform(self, mat): ++ ''' ++ Transforms this map element by the given 4x4 transformation matrix. ++ ''' ++ pts = np.concatenate([self.left, np.ones((self.left.shape[0], 1))], axis=1) ++ pts = np.dot(mat, pts.T).T ++ self.left = pts[:,:3] ++ pts = np.concatenate([self.right, np.ones((self.right.shape[0], 1))], axis=1) ++ pts = np.dot(mat, pts.T).T ++ self.right = pts[:,:3] ++ ++ def __repr__(self): ++ return '' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) ++ ++class NVMapTile(object): ++ def __init__(self, segments, ++ name=None): ++ ''' ++ :param segments: dict mapping seg_id -> RoadSegment objects for all road segments in the tile. ++ ''' ++ self.segments = segments ++ self.name = name ++ ++ def transform(self, mat): ++ ''' ++ Multiply all elements of the map tile by the given 4x4 matrix ++ ''' ++ for seg_id, seg in self.segments.items(): ++ seg.transform(mat) ++ ++ def __repr__(self): ++ return '' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) ++ ++ ++####################################################################################################################### ++ ++# ++# Utils from Zan for converting between GPS and world coordinates ++# ++ ++from scipy.spatial.transform import Rotation as R ++ ++ ++def lat_lng_alt_2_ecef(lat_lng_alt, orientation_axis, orientation_angle, earth_model='WGS84'): ++ ''' Computes the transformation from the world pose coordiante system to the earth centered earth fixed (ECEF) one ++ Args: ++ lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3] ++ orientation_axis (np.array): orientation in the local ENU coordinate system [n,3] ++ orientation_angle (np.array): orientation angle of the local ENU coordinate system in degrees [n,1] ++ earth_model (string): earth model used for conversion (spheric will be unaccurate when maps are large) ++ Out: ++ trans (np.array): transformation parameters from world pose to ECEF coordinate system in se3 form (n, 4, 4) ++ ''' ++ n = lat_lng_alt.shape[0] ++ trans = np.tile(np.eye(4).reshape(1,4,4),[n,1,1]) ++ ++ theta = (90. - lat_lng_alt[:, 0]) * np.pi/180 ++ phi = lat_lng_alt[:, 1] * np.pi/180 ++ ++ R_enu_ecef = local_ENU_2_ECEF_orientation(theta, phi) ++ ++ if earth_model == 'WGS84': ++ a = 6378137.0 ++ flattening = 1.0 / 298.257223563 ++ b = a * (1.0 - flattening) ++ translation = lat_lng_alt_2_translation_ellipsoidal(lat_lng_alt, a, b) ++ ++ elif earth_model == 'sphere': ++ earth_radius = 6378137.0 # Earth radius in meters ++ z_dir = np.concatenate([(np.sin(theta)*np.cos(phi))[:,None], ++ (np.sin(theta)*np.sin(phi))[:,None], ++ (np.cos(theta))[:,None] ],axis=1) ++ ++ translation = (earth_radius + lat_lng_alt[:, -1])[:,None] * z_dir ++ ++ else: ++ raise ValueError ("Selected ellipsoid not implemented!") ++ ++ world_pose_orientation = axis_angle_2_so3(orientation_axis, orientation_angle) ++ ++ trans[:,:3,:3] = R_enu_ecef @ world_pose_orientation ++ trans[:,:3,3] = translation ++ ++ return trans ++ ++def local_ENU_2_ECEF_orientation(theta, phi): ++ ''' Computes the rotation matrix between the world_pose and ECEF coordinate system ++ Args: ++ theta (np.array): theta coordinates in radians [n,1] ++ phi (np.array): phi coordinates in radians [n,1] ++ Out: ++ (np.array): rotation from world pose to ECEF in so3 representation [n,3,3] ++ ''' ++ z_dir = np.concatenate([(np.sin(theta)*np.cos(phi))[:,None], ++ (np.sin(theta)*np.sin(phi))[:,None], ++ (np.cos(theta))[:,None] ],axis=1) ++ z_dir = z_dir/np.linalg.norm(z_dir, axis=-1, keepdims=True) ++ ++ y_dir = np.concatenate([-(np.cos(theta)*np.cos(phi))[:,None], ++ -(np.cos(theta)*np.sin(phi))[:,None], ++ (np.sin(theta))[:,None] ],axis=1) ++ y_dir = y_dir/np.linalg.norm(y_dir, axis=-1, keepdims=True) ++ ++ x_dir = np.cross(y_dir, z_dir) ++ ++ return np.concatenate([x_dir[:,:,None], y_dir[:,:,None], z_dir[:,:,None]], axis = -1) ++ ++ ++def lat_lng_alt_2_translation_ellipsoidal(lat_lng_alt, a, b): ++ ''' Computes the translation based on the ellipsoidal earth model ++ Args: ++ lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3] ++ a (float/double): Semi-major axis of the ellipsoid ++ b (float/double): Semi-minor axis of the ellipsoid ++ Out: ++ (np.array): translation from world pose to ECEF [n,3] ++ ''' ++ ++ phi = lat_lng_alt[:, 0] * np.pi/180 ++ gamma = lat_lng_alt[:, 1] * np.pi/180 ++ ++ cos_phi = np.cos(phi) ++ sin_phi = np.sin(phi) ++ cos_gamma = np.cos(gamma) ++ sin_gamma = np.sin(gamma) ++ e_square = (a * a - b * b) / (a * a) ++ ++ N = a / np.sqrt(1 - e_square * sin_phi * sin_phi) ++ ++ ++ x = (N + lat_lng_alt[:, 2]) * cos_phi * cos_gamma ++ y = (N + lat_lng_alt[:, 2]) * cos_phi * sin_gamma ++ z = (N * (b*b)/(a*a) + lat_lng_alt[:, 2]) * sin_phi ++ ++ return np.concatenate([x[:,None] ,y[:,None], z[:,None]], axis=1 ) ++ ++def axis_angle_2_so3(axis, angle, degrees=True): ++ ''' Converts the axis angle representation of the so3 rotation matrix ++ Args: ++ axis (np.array): the rotation axes [n,3] ++ angle float/double: rotation angles either in degrees or radians [n] ++ degrees bool: True if angle is given in degrees else False ++ ++ Out: ++ (np array): rotations given so3 matrix representation [n,3,3] ++ ''' ++ # Treat angle (radians) below this as 0. ++ cutoff_angle = 1e-9 if not degrees else 1e-9*180/np.pi ++ angle[angle < cutoff_angle] = 0.0 ++ ++ # Scale the axis to have the norm representing the angle ++ if degrees: ++ angle = np.radians(angle) ++ axis_angle = (angle/np.linalg.norm(axis, axis=1, keepdims=True)) * axis ++ ++ return R.from_rotvec(axis_angle).as_matrix() ++ ++def ecef_2_lat_lng_alt(trans, earth_model='WGS84'): ++ ''' Converts the transformation from the earth centered earth fixed (ECEF) coordinate frame to the world pose ++ Args: ++ trans (np.array): transformation parameters in ECEF [n,4,4] ++ earth_model (string): earth model used for conversion (spheric will be unaccurate when maps are large) ++ Out: ++ lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3] ++ orientation_axis (np.array): orientation in the local ENU coordinate system [n,3] ++ orientation_angle (np.array): orientation angle of the local ENU coordinate system in degrees [n,1] ++ ''' ++ ++ translation = trans[:,:3,3] ++ rotation = trans[:,:3,:3] ++ ++ if earth_model == 'WGS84': ++ a = 6378137.0 ++ flattening = 1.0 / 298.257223563 ++ lat_lng_alt = translation_2_lat_lng_alt_ellipsoidal(translation, a, flattening) ++ ++ elif earth_model == 'sphere': ++ earth_radius = 6378137.0 # Earth radius in meters ++ lat_lng_alt = translation_2_lat_lng_alt_spherical(translation, earth_radius) ++ ++ else: ++ raise ValueError ("Selected ellipsoid not implemented!") ++ ++ ++ # Compute the orientation axis and angle ++ theta = (90. - lat_lng_alt[:, 0]) * np.pi/180 ++ phi = lat_lng_alt[:, 1] * np.pi/180 ++ ++ R_ecef_enu = local_ENU_2_ECEF_orientation(theta, phi).transpose(0,2,1) ++ ++ orientation = R_ecef_enu @ rotation ++ orientation_axis, orientation_angle = so3_2_axis_angle(orientation) ++ ++ ++ return lat_lng_alt, orientation_axis, orientation_angle ++ ++def translation_2_lat_lng_alt_spherical(translation, earth_radius): ++ ''' Computes the translation in the ECEF to latitude, longitude, altitude based on the spherical earth model ++ Args: ++ translation (np.array): translation in the ECEF coordinate frame (in meters) [n,3] ++ earth_radius (float/double): earth radius ++ Out: ++ (np.array): latitude, longitude and altitude [n,3] ++ ''' ++ altitude = np.linalg.norm(translation, axis=-1) - earth_radius ++ latitude = 90 - np.arccos(translation[:,2] / np.linalg.norm(translation, axis=-1, keepdims=True)) * 180/np.pi ++ longitude = np.arctan2(translation[:,1],translation[:,0]) * 180/np.pi ++ ++ return np.concatenate([latitude[:,None], longitude[:,None], altitude[:,None]], axis=1) ++ ++def translation_2_lat_lng_alt_ellipsoidal(translation, a, f): ++ ''' Computes the translation in the ECEF to latitude, longitude, altitude based on the ellipsoidal earth model ++ Args: ++ translation (np.array): translation in the ECEF coordinate frame (in meters) [n,3] ++ a (float/double): Semi-major axis of the ellipsoid ++ f (float/double): flattening factor of the earth ++ radius ++ Out: ++ (np.array): latitude, longitude and altitude [n,3] ++ ''' ++ ++ # Compute support parameters ++ f0 = (1 - f) * (1 - f) ++ f1 = 1 - f0 ++ f2 = 1 / f0 - 1 ++ ++ z_div_1_f = translation[:,2] / (1 - f) ++ x2y2 = np.square(translation[:,0]) + np.square(translation[:,1]) ++ ++ x2y2z2 = x2y2 + z_div_1_f*z_div_1_f ++ x2y2z2_pow_3_2 = x2y2z2 * np.sqrt(x2y2z2) ++ ++ gamma = (x2y2z2_pow_3_2 + a * f2 * z_div_1_f * z_div_1_f) / (x2y2z2_pow_3_2 - a * f1 * x2y2) * translation[:,2] / np.sqrt(x2y2) ++ ++ longitude = np.arctan2(translation[:,1], translation[:,0]) * 180/np.pi ++ latitude = np.arctan(gamma) * 180/np.pi ++ altitude = np.sqrt(1 + np.square(gamma)) * (np.sqrt(x2y2) - a / np.sqrt(1 + f0 * np.square(gamma))) ++ ++ return np.concatenate([latitude[:,None], longitude[:,None], altitude[:,None]], axis=1) ++ ++def so3_2_axis_angle(so3, degrees=True): ++ ''' Converts the so3 representation to axis_angle ++ Args: ++ so3 (np.array): the rotation matrices [n,3,3] ++ degrees bool: True if angle should be given in degrees ++ ++ Out: ++ axis (np array): the rotation axis [n,3] ++ angle (np array): the rotation angles, either in degrees (if degrees=True) or radians [n,] ++ ''' ++ rot_vec = R.from_matrix(so3).as_rotvec() ++ ++ angle = np.linalg.norm(rot_vec, axis=-1, keepdims=True) ++ axis = rot_vec / angle ++ if degrees: ++ angle = np.degrees(angle) ++ ++ return axis, angle ++ ++####################################################################################################################### ++ ++# ++# Utils for loading in ego and autolabel pose data for session ++# ++ ++import datetime ++import pickle ++ ++from scipy import spatial, interpolate ++ ++# NV_EGO_LWH = [4.084, 1.73, 1.562] # this is the nuscenes measurements ++NV_EGO_LWH = [5.30119, 2.1133, 1.49455] # actual hyperion 8 ++ ++class PoseInterpolator: ++ ''' Interpolates the poses to the desired time stamps. The translation component is interpolated linearly, ++ while spherical linear interpolation (SLERP) is used for the rotations. ++ https://en.wikipedia.org/wiki/Slerp ++ ++ Args: ++ poses (np.array): poses at given timestamps in a se3 representation [n,4,4] ++ timestamps (np.array): timestamps of the known poses [n] ++ ts_target (np.array): timestamps for which the poses will be interpolated [m,1] ++ Out: ++ (np.array): interpolated poses in se3 representation [m,4,4] ++ ''' ++ def __init__(self, poses, timestamps): ++ ++ self.slerp = spatial.transform.Slerp(timestamps, R.from_matrix(poses[:,:3,:3])) ++ self.f_x = interpolate.interp1d(timestamps, poses[:,0,3]) ++ self.f_y = interpolate.interp1d(timestamps, poses[:,1,3]) ++ self.f_z = interpolate.interp1d(timestamps, poses[:,2,3]) ++ ++ self.last_row = np.array([0,0,0,1]).reshape(1,1,-1) ++ ++ def interpolate_to_timestamps(self, ts_target): ++ x_interp = self.f_x(ts_target).reshape(-1,1,1) ++ y_interp = self.f_y(ts_target).reshape(-1,1,1) ++ z_interp = self.f_z(ts_target).reshape(-1,1,1) ++ R_interp = self.slerp(ts_target).as_matrix().reshape(-1,3,3) ++ ++ t_interp = np.concatenate([x_interp,y_interp,z_interp],axis=-2) ++ ++ return np.concatenate((np.concatenate([R_interp,t_interp],axis=-1), np.tile(self.last_row,(R_interp.shape[0],1,1))), axis=1) ++ ++def angle_diff(x, y, period=2*np.pi): ++ ''' ++ Get the smallest angle difference between 2 angles: the angle from y to x. ++ :param x: angle 1 (B) ++ :param y: angle 2 (B) ++ :param period: periodicity in radians for assessing difference. ++ :return diff: smallest angle difference between to angles (B) ++ ''' ++ # calculate angle difference, modulo to [0, 2*pi] ++ diff = (x - y + period / 2) % period - period / 2 ++ diff[diff > np.pi] = diff[diff > np.pi] - (2 * np.pi) # shift (pi, 2*pi] to (-pi, 0] ++ return diff ++ ++def load_ego_pose_from_image_meta(images_path, map_tile=None): ++ ''' ++ Loads in the SDC pose from the metadata attached to a session image stream. ++ ++ :param images_path str: directory of the images/metadata. Should contain *.pkl for each frame and timestamps.npz ++ :param map_tile: Tile object, if given translates the ego trajectory so in the same frame as this map tile. ++ ''' ++ frame_meta = sorted(glob.glob(os.path.join(images_path, '*.pkl'))) ++ timestamps_pth = os.path.join(images_path, 'timestamps.npz') ++ ++ # load in timesteps ++ # ego_t = np.load(timestamps_pth)['frame_t'] ++ ++ ego_poses = [] ++ ego_t = [] ++ for meta_file in frame_meta: ++ with open(meta_file, 'rb') as f: ++ cur_meta = pickle.load(f) ++ # ego_poses.append(cur_meta['ego_pose_s']) ++ # ego_poses.append(cur_meta['ego_pose_timestamps'][0]) ++ ego_poses.append(cur_meta['ego_pose_e']) ++ ego_t.append(cur_meta['ego_pose_timestamps'][1]) ++ ego_poses = np.stack(ego_poses, axis=0) ++ ego_t = np.array(ego_t) ++ ++ # pose_sidx = int(frame_meta[0].split('/')[-1].split('.')[0]) ++ # pose_eidx = int(frame_meta[-1].split('/')[-1].split('.')[0]) + 1 ++ # ego_t = ego_t[pose_sidx:pose_eidx] ++ ++ if map_tile is not None: ++ ego_poses[:, :2, -1] += map_tile.trans_offset ++ ++ return ego_poses, ego_t ++ ++def check_time_overlap(s0, e0, s1, e1): ++ overlap = (s0 < s1 and e0 > s1) or \ ++ (s0 > s1 and s0 < e1) or \ ++ (s1 < s0 and e1 > s0) or \ ++ (s1 > s0 and s1 < e0) ++ return overlap ++ ++ ++def load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ map_tile=None, ++ frame_range=None, ++ postprocess=True, ++ fill_first_n=None, ++ mine_dups=False, ++ extra_obj_path=None, ++ load_ids=None, ++ crop2valid=True): ++ ''' ++ This only loads labeled trajectories that are available at the same section ++ as the ego labels. ++ ++ :param autolabels_path str: pkl file to load autolabels from ++ :param ego_images_path str: directory containing image metadata to load ego poses from ++ :param lidar_rig_path str: npz containing lidar2rig transform for ego ++ :param map_tile: Tile object, if given translates the trajectories so in the same frame as this map tile. ++ :param frame_range tuple: If given (start, end) only loads in this frame range (wrt the ego sequence) ++ :param postprocess bool: If true, runs some post-processing to associate track and heuristically remove rotation flips. ++ :param extra_obj_path str: if given, load in an additional trajectory from here and includes in the data ++ :return traj_T: ++ ''' ++ # Load the autolabels ++ with open(autolabels_path, 'rb') as f: ++ labels = pickle.load(f) ++ ++ # Load the poses and the timestamps ++ ego_poses, ego_pose_timestamps = load_ego_pose_from_image_meta(ego_images_path) ++ if frame_range is not None: ++ assert frame_range[1] > frame_range[0] ++ assert (frame_range[0] >= 0 and frame_range[0] <= ego_poses.shape[0]) ++ ego_poses = ego_poses[frame_range[0]:frame_range[1]] ++ ego_pose_timestamps = ego_pose_timestamps[frame_range[0]:frame_range[1]] ++ ++ # if ego_poses are not in rivermark frame, need to take them so can load autolabels ++ ego_poses_rivermark = ego_poses ++ if GLOBAL_BASE_POSE is not GLOBAL_BASE_POSE_RIVERMARK: ++ ego_poses_ecef = np.matmul(GLOBAL_BASE_POSE[np.newaxis], ego_poses) ++ ego_poses_rivermark = np.matmul(np.linalg.inv(GLOBAL_BASE_POSE_RIVERMARK)[np.newaxis], ego_poses_ecef) ++ ++ # Load the lidar to rig transformation parameters and timestamps ++ T_lidar_rig = np.load(lidar_rig_path)['T_lidar_rig'] ++ ++ # first pass to break tracks into contiguous subsequences ++ # and merge manually given missed associations ++ track_seqs = dict() ++ processed_ids = set() ++ updated_labels = dict() ++ for track_id, track in labels.items(): ++ if load_ids is not None and track_id not in load_ids: ++ continue ++ if track_id in processed_ids or track_id in TRACK_REMOVE: ++ # already processed this through association ++ # or should be removed ++ continue ++ ++ obj_ts = track['3D_bbox'][:,0] ++ if track_id in TRACK_ASSOC: ++ # stack all the data from all associated tracks ++ # NOTE: this assumes TRACK_ASSOC is sorted in temporal order already ++ assoc_data = [labels[assoc_track_id]['3D_bbox'] for assoc_track_id in TRACK_ASSOC[track_id]] ++ # if association is wrong, the tracks may overlap ++ valid_assoc = [not check_time_overlap(obj_ts[0], obj_ts[-1], assoc_label[0,0], assoc_label[-1,0]) for assoc_label in assoc_data] ++ if np.sum(valid_assoc) != len(valid_assoc): ++ print('Invalid associations for track_id %s!!' % (track_id)) ++ print('Ignoring: ') ++ print(np.array(TRACK_ASSOC[track_id])[~np.array(valid_assoc)]) ++ assoc_data = [assoc_label for aid, assoc_label in enumerate(assoc_data) if valid_assoc[aid]] ++ if len(assoc_data) > 0: ++ assoc_bbox = np.concatenate([track['3D_bbox']] + assoc_data, axis=0) ++ updated_labels[track_id] = {'3D_bbox' : assoc_bbox, 'type' : track['type']} ++ obj_ts = assoc_bbox[:,0] ++ processed_ids.update(TRACK_ASSOC[track_id]) ++ else: ++ updated_labels[track_id] = track ++ else: ++ updated_labels[track_id] = track ++ ++ if len(obj_ts) < 2: ++ # make sure track is longer than single frame ++ continue ++ track_seqs[track_id] = [] ++ # larger than 3 timesteps considered a break, otherwise should be reasonable to interpolate ++ # should we do even larger? ++ track_break = np.diff(1e-6*obj_ts) > (AUTOLABEL_DT*3 + AUTOLABEL_DT*0.5) ++ seq_sidx = 0 ++ for tidx in range(1, obj_ts.shape[0]): ++ if track_break[tidx-1]: ++ track_seqs[track_id].append((seq_sidx, tidx)) ++ seq_sidx = tidx ++ track_seqs[track_id].append((seq_sidx, obj_ts.shape[0])) ++ processed_ids.add(track_id) ++ ++ # load extra object ++ if extra_obj_path is not None: ++ extra_obj_data = np.load(extra_obj_path) ++ extra_obj_poses = extra_obj_data['poses'] ++ extra_obj_timestamps = extra_obj_data['pose_timestamps'] ++ extra_obj_lwh = EXTRA_DYN_LWH ++ extra_track = { ++ 'poses' : extra_obj_poses, ++ 'timestamps' : extra_obj_timestamps, ++ 'lwh' : extra_obj_lwh, ++ 'type' : 'car' ++ } ++ updated_labels['extra'] = extra_track ++ track_seqs['extra'] = [(0,extra_obj_poses.shape[0])] ++ ++ # collect all tracks that overlap with ego data ++ traj_poses = [] ++ traj_valid = [] ++ traj_lwh = [] ++ traj_ids = [] ++ for track_id, cur_track_seqs in track_seqs.items(): ++ track = updated_labels[track_id] ++ if track_id == 'extra': ++ all_obj_ts = track['timestamps'] ++ all_obj_dat = track['poses'] ++ obj_lwh = track['lwh'] ++ else: ++ all_obj_ts = track['3D_bbox'][:,0] ++ all_obj_dat = track['3D_bbox'] ++ obj_lwh = np.median(all_obj_dat[:,4:7], axis=0) # use all timesteps for bbox size ++ # will fill these in as we go through each subseq ++ full_obj_traj = np.ones_like(ego_poses_rivermark)*np.nan ++ obj_valid = np.zeros((full_obj_traj.shape[0]), dtype=bool) ++ for seq_sidx, seq_eidx in cur_track_seqs: ++ obj_ts = all_obj_ts[seq_sidx:seq_eidx] ++ if (obj_ts[0] >= ego_pose_timestamps[0] and obj_ts[0] <= ego_pose_timestamps[-1]) or \ ++ (obj_ts[-1] >= ego_pose_timestamps[0] and obj_ts[-1] <= ego_pose_timestamps[-1]) or \ ++ (obj_ts[0] <= ego_pose_timestamps[0] and obj_ts[-1] >= ego_pose_timestamps[-1]): ++ obj_type = track['type'] ++ # if obj_type != 'car': ++ # continue ++ obj_dat = all_obj_dat[seq_sidx:seq_eidx] ++ # find steps overlapping the ego sequence ++ valid_ts = np.logical_and(obj_ts >= ego_pose_timestamps[0], obj_ts <= ego_pose_timestamps[-1]) ++ ++ overlap_inds = np.nonzero(valid_ts)[0] ++ if len(overlap_inds) < 2: ++ continue # need more than 1 frame overlap ++ sidx = np.amin(overlap_inds) ++ eidx = np.amax(overlap_inds)+1 ++ ++ obj_ts = obj_ts[sidx:eidx] ++ # some poses have the same timestep -- drop these so we can interpolate ++ valid_t = np.diff(obj_ts) > 0 ++ valid_t = np.append(valid_t, [True]) ++ if not valid_t[0]: ++ # want to keep the edge times in tact since these surround ego times ++ valid_t[0] = True ++ valid_t[1] = False ++ obj_ts = obj_ts[valid_t] ++ ++ if track_id == 'extra': ++ print(obj_ts) ++ glob_obj_poses = obj_dat[sidx:eidx] ++ print(glob_obj_poses.shape) ++ # exit() ++ else: ++ obj_pos = obj_dat[sidx:eidx,1:4][valid_t] ++ # print(obj_pos) ++ # print(obj_dat[sidx:eidx,4:7][valid_t]) ++ obj_rot_eulxyz = obj_dat[sidx:eidx,7:][valid_t] ++ obj_rot_eulxyz[obj_rot_eulxyz[:,2] < -np.pi, 2] += (2 * np.pi) ++ obj_rot_eulxyz[obj_rot_eulxyz[:,2] > np.pi, 2] -= (2 * np.pi) ++ obj_R = R.from_euler('xyz', obj_rot_eulxyz, degrees=False).as_matrix() ++ # build transformation matrix (pose sequence) ++ obj_poses = np.repeat(np.eye(4)[np.newaxis], len(obj_ts), axis=0) ++ obj_poses[:,:3,:3] = obj_R ++ obj_poses[:,:3,-1] = obj_pos ++ ++ # need to interpolate the ego pose to transform from lidar frame to global ++ overlap_ego_mask = np.logical_and(ego_pose_timestamps >= obj_ts[0] - 1e6, ego_pose_timestamps <= obj_ts[-1] + 1e6) # add 1 sec around so can interp first/last frames ++ overlap_ego_t = ego_pose_timestamps[overlap_ego_mask] ++ overlap_ego_poses = ego_poses_rivermark[overlap_ego_mask] ++ ego_interp = PoseInterpolator(overlap_ego_poses, overlap_ego_t) ++ T_rig_global = ego_interp.interpolate_to_timestamps(obj_ts) ++ ++ # transform to rig frame from lidar ++ rig_obj_poses = np.matmul(T_lidar_rig[np.newaxis], obj_poses) ++ ++ # print('elev') ++ # print(rig_obj_poses[:,2, 3]) ++ # print('height') ++ # print(obj_dat[sidx:eidx,6][valid_t]) ++ ++ # transform to global frame (w.r.t rivermark) from rig ++ glob_obj_poses = np.matmul(T_rig_global, rig_obj_poses) ++ # now to the desired global frame ++ if GLOBAL_BASE_POSE is not GLOBAL_BASE_POSE_RIVERMARK: ++ # to ECEF ++ glob_obj_poses = np.matmul(GLOBAL_BASE_POSE_RIVERMARK[np.newaxis], glob_obj_poses) ++ # to desired global pose ++ glob_obj_poses = np.matmul(np.linalg.inv(GLOBAL_BASE_POSE)[np.newaxis], glob_obj_poses) ++ ++ if postprocess and track_id != 'extra': ++ # we're going to collect frames with "correct" orientations ++ # by first looking at dynamic frames where can use motion to infer correct orientation, ++ # then using dynamic to determine correctness of static frames. ++ # then we can interpolate between all these correct frames. ++ glob_hvec = glob_obj_poses[:,:2,0] # x-axis ++ glob_hvec = glob_hvec / np.linalg.norm(glob_hvec, axis=-1, keepdims=True) ++ glob_yaw = np.arctan2(glob_hvec[:,1], glob_hvec[:,0]) ++ glob_pos = glob_obj_poses[:,:2,3] # 2d ++ ++ # TODO add smoothing to the position to avoid noisy velocities ++ obj_vel = np.diff(glob_pos[:,:2], axis=0) / np.diff(obj_ts*1e-6)[:,np.newaxis] ++ obj_vel = np.concatenate([obj_vel, obj_vel[-1:,:]], axis=0) ++ # is_dynamic = np.median(np.linalg.norm(obj_vel, axis=1)) > 2.0 # m/s ++ is_dynamic = np.linalg.norm(obj_vel, axis=1) > 2.0 # m/s ++ ++ is_correct_mask = np.zeros((glob_pos.shape[0]), dtype=bool) ++ ++ # dynamic first ++ if np.sum(is_dynamic) > 0: ++ dynamic_vel = obj_vel[is_dynamic] ++ dynamic_yaw = glob_yaw[is_dynamic] ++ dynamic_vel_norm = np.linalg.norm(dynamic_vel, axis=1, keepdims=True) ++ vel_dir = dynamic_vel / (dynamic_vel_norm + 1e-9) ++ head_dir = np.concatenate([np.cos(dynamic_yaw[:,np.newaxis]), np.sin(dynamic_yaw[:,np.newaxis])], axis=1) ++ vel_head_dot = np.sum(vel_dir * head_dir, axis=1) ++ dynamic_correct = vel_head_dot > 0 ++ is_correct_mask[is_dynamic] = dynamic_correct ++ ++ # now static, by referencing closest correct dynamic ++ dynamic_inds = np.nonzero(np.logical_and(is_dynamic, is_correct_mask))[0] ++ static_inds = np.nonzero(~is_dynamic)[0] ++ if len(static_inds) > 0: ++ # if no dynamic frames ++ # assume correct orientation has the most frequent sign ++ num_pos = np.sum(glob_yaw[~is_dynamic] >= 0) ++ num_neg = np.sum(glob_yaw[~is_dynamic] < 0) ++ for static_ind in static_inds: ++ if len(dynamic_inds) > 0: ++ closest_dyn_ind = np.argmin(np.abs(static_ind - dynamic_inds)) ++ dyn_stat_dot = np.sum(glob_hvec[closest_dyn_ind]*glob_hvec[static_ind]) ++ if dyn_stat_dot > 0: # going in same direction ++ is_correct_mask[static_ind] = True ++ else: ++ is_wrong = glob_yaw[static_ind] >= 0 if num_neg > num_pos else glob_yaw[static_ind] < 0 ++ is_correct_mask[static_ind] = not is_wrong ++ ++ if np.sum(is_correct_mask) > 0: ++ fix_interp_poses = glob_obj_poses[is_correct_mask] ++ fix_interp_t = obj_ts[is_correct_mask] ++ # what if edges are not correct? ++ if not is_correct_mask[0]: ++ # just pad first correct to beginning ++ fix_interp_poses = np.concatenate([fix_interp_poses[0:1], fix_interp_poses], axis=0) ++ fix_interp_t = np.concatenate([[obj_ts[0]], fix_interp_t], axis=0) ++ if not is_correct_mask[-1]: ++ # just pad first correct to beginning ++ fix_interp_poses = np.concatenate([fix_interp_poses, fix_interp_poses[-1:]], axis=0) ++ fix_interp_t = np.concatenate([fix_interp_t, [obj_ts[-1]]], axis=0) ++ ++ # now interpolate between correct frames ++ fix_flip_interp = PoseInterpolator(fix_interp_poses, fix_interp_t) ++ fixed_rot_poses = fix_flip_interp.interpolate_to_timestamps(obj_ts) ++ glob_obj_poses[:,:3,:3] = fixed_rot_poses[:,:3,:3] # don't want to update translation ++ ++ # after processing interpolate to the relevant ego timestamps (upsample from 10Hz to 30Hz) ++ obj_interp = PoseInterpolator(glob_obj_poses, obj_ts) ++ overlap_ego_mask = np.logical_and(ego_pose_timestamps >= obj_ts[0], ego_pose_timestamps <= obj_ts[-1]) ++ overlap_ego_t = ego_pose_timestamps[overlap_ego_mask] ++ glob_obj_poses = obj_interp.interpolate_to_timestamps(overlap_ego_t) ++ ++ # update full seq information ++ full_obj_traj[overlap_ego_mask] = glob_obj_poses ++ obj_valid[overlap_ego_mask] = True ++ ++ if np.sum(obj_valid) > 0: ++ traj_poses.append(full_obj_traj) ++ traj_valid.append(obj_valid) ++ traj_lwh.append(obj_lwh) ++ traj_ids.append(track_id) ++ ++ traj_poses = np.stack(traj_poses, axis=0) ++ traj_valid = np.stack(traj_valid, axis=0) ++ traj_lwh = np.stack(traj_lwh, axis=0) ++ traj_ids = np.array(traj_ids) ++ ++ if crop2valid: ++ # we interpolated inside ego timestamp maximum, so have to crop a bit ++ val_inds = np.nonzero(np.sum(traj_valid, axis=0) > 0)[0] ++ start_valid = np.amin(val_inds) ++ end_valid = np.amax(val_inds)+1 ++ traj_poses = traj_poses[:,start_valid:end_valid] ++ traj_valid = traj_valid[:,start_valid:end_valid] ++ ego_poses = ego_poses[start_valid:end_valid] ++ ego_pose_timestamps = ego_pose_timestamps[start_valid:end_valid] ++ ++ print(start_valid) ++ print(end_valid) ++ ++ if fill_first_n is not None: ++ # for each trajectory, make sure the first n steps are ++ # all valid either by interpolation or extrapolation. ++ all_ts = ego_pose_timestamps*1e-6 ++ for ai in range(traj_poses.shape[0]): ++ cur_poses = traj_poses[ai] ++ cur_trans = cur_poses[:,:3,3] ++ cur_R = cur_poses[:,:3,:3] ++ cur_valid = traj_valid[ai] ++ ++ if np.sum(cur_valid) < 30: ++ # if does't show up for at least a second throughout, don't be extrapolating ++ continue ++ ++ first_n_valid = cur_valid[:fill_first_n] ++ if np.sum(~first_n_valid) == 0: ++ continue ++ ++ first_n_timestamps = all_ts[:fill_first_n] ++ ++ all_val_inds = sorted(np.nonzero(cur_valid)[0]) ++ first_val_idx = all_val_inds[0] ++ last_val_idx = all_val_inds[-1] ++ ++ # interp steps are those between the first and last valid steps ++ first_n_steps = np.arange(min(fill_first_n, all_ts.shape[0])) ++ interp_steps = np.logical_and(first_n_steps >= first_val_idx, first_n_steps <= last_val_idx) ++ first_n_interp = None ++ if np.sum(interp_steps) > 0: ++ first_n_interp = PoseInterpolator(cur_poses[cur_valid], all_ts[cur_valid]) ++ first_n_interp = first_n_interp.interpolate_to_timestamps(first_n_timestamps[interp_steps]) ++ ++ # extrap fw are those past last valid step (disappear) ++ extrap_fw_steps = first_n_steps > last_val_idx ++ first_n_extrap_fw = None ++ if np.sum(extrap_fw_steps) > 0: ++ if last_val_idx > 0 and cur_valid[last_val_idx-1]: ++ # need to compute velocity to extrapolate ++ dt = first_n_timestamps[extrap_fw_steps] - all_ts[last_val_idx] ++ dt = dt[:,np.newaxis] ++ last_pos = np.repeat(cur_trans[last_val_idx][np.newaxis], dt.shape[0], axis=0) ++ # translation ++ last_lin_vel = (cur_trans[last_val_idx] - cur_trans[last_val_idx-1]) / (all_ts[last_val_idx] - all_ts[last_val_idx-1]) ++ last_lin_vel = np.repeat(last_lin_vel[np.newaxis], dt.shape[0], axis=0) ++ extrap_trans = last_pos + last_lin_vel*dt ++ # copy rotation ++ extrap_R = np.repeat(cur_R[last_val_idx:last_val_idx+1], dt.shape[0], axis=0) ++ # # extrapolate rotation ++ # last_delta_R = np.dot(cur_R[last_val_idx], cur_R[last_val_idx-1].T) ++ # last_delta_rotvec = R.from_matrix(last_delta_R).as_rotvec() ++ # last_delta_angle = np.linalg.norm(last_delta_rotvec) ++ # last_delta_axis = last_delta_rotvec / (last_delta_angle + 1e-9) ++ # last_ang_vel = last_delta_angle / (all_ts[last_val_idx] - all_ts[last_val_idx-1]) ++ # last_ang_vel = np.repeat(last_ang_vel[np.newaxis,np.newaxis], dt.shape[0], axis=0) ++ # extrap_angle = last_ang_vel*dt ++ # last_delta_axis = np.repeat(last_delta_axis[np.newaxis], dt.shape[0], axis=0) ++ # extrap_rotvec = extrap_angle*last_delta_axis ++ # extrap_delta_R = R.from_rotvec(extrap_rotvec).as_matrix() ++ # extrap_R = np.matmul(extrap_delta_R, cur_R[last_val_idx:last_val_idx+1]) ++ ++ # put together ++ extrap_poses = np.repeat(np.eye(4)[np.newaxis], dt.shape[0], axis=0) ++ extrap_poses[:,:3,:3] = extrap_R ++ extrap_poses[:,:3,3] = extrap_trans ++ first_n_extrap_fw = extrap_poses ++ ++ # extrap bw are those before first valid step (appear) ++ extrap_bw_steps = first_n_steps < first_val_idx ++ first_n_extrap_bw = None ++ if np.sum(extrap_bw_steps) > 0: ++ if first_val_idx < (cur_valid.shape[0]-1) and cur_valid[first_val_idx+1]: ++ # need to compute velocity to extrapolate ++ dt = first_n_timestamps[extrap_bw_steps] - all_ts[first_val_idx] # note will be < 0 ++ dt = dt[:,np.newaxis] ++ first_pos = np.repeat(cur_trans[first_val_idx][np.newaxis], dt.shape[0], axis=0) ++ # translation ++ first_lin_vel = (cur_trans[first_val_idx+1] - cur_trans[first_val_idx]) / (all_ts[first_val_idx+1] - all_ts[first_val_idx]) ++ first_lin_vel = np.repeat(first_lin_vel[np.newaxis], dt.shape[0], axis=0) ++ extrap_trans = first_pos + first_lin_vel*dt ++ # copy rotation ++ extrap_R = np.repeat(cur_R[first_val_idx:first_val_idx+1], dt.shape[0], axis=0) ++ # # extrapolate rotation ++ # first_delta_R = np.dot(cur_R[first_val_idx+1], cur_R[first_val_idx].T) ++ # first_delta_rotvec = R.from_matrix(first_delta_R).as_rotvec() ++ # first_delta_angle = np.linalg.norm(first_delta_rotvec) ++ # first_delta_axis = first_delta_rotvec / (first_delta_angle+1e-9) ++ # first_ang_vel = first_delta_angle / (all_ts[first_val_idx+1] - all_ts[first_val_idx]) ++ # first_ang_vel = np.repeat(first_ang_vel[np.newaxis,np.newaxis], dt.shape[0], axis=0) ++ # extrap_angle = first_ang_vel*dt ++ # first_delta_axis = np.repeat(first_delta_axis[np.newaxis], dt.shape[0], axis=0) ++ # extrap_rotvec = extrap_angle*first_delta_axis ++ # extrap_delta_R = R.from_rotvec(extrap_rotvec).as_matrix() ++ # extrap_R = np.matmul(extrap_delta_R, cur_R[first_val_idx:first_val_idx+1]) ++ # put together ++ extrap_poses = np.repeat(np.eye(4)[np.newaxis], dt.shape[0], axis=0) ++ extrap_poses[:,:3,:3] = extrap_R ++ extrap_poses[:,:3,3] = extrap_trans ++ first_n_extrap_bw = extrap_poses ++ ++ first_n_poses = cur_poses[:fill_first_n] ++ if first_n_interp is not None: ++ first_n_poses[interp_steps] = first_n_interp ++ if first_n_extrap_fw is not None: ++ first_n_poses[extrap_fw_steps] = first_n_extrap_fw ++ if first_n_extrap_bw is not None: ++ first_n_poses[extrap_bw_steps] = first_n_extrap_bw ++ ++ first_n_valid = np.sum(np.isnan(first_n_poses.reshape((first_n_poses.shape[0], 16))), axis=1) == 0 ++ ++ traj_poses[ai, :fill_first_n] = first_n_poses ++ traj_valid[ai, :fill_first_n] = first_n_valid ++ ++ # if traj_ids[ai] == 2588: ++ # print(extrap_fw_steps) ++ # print(first_n_extrap_fw) ++ # exit() ++ ++ # based on fill-in can tell if there are duplicated (mis-associated tracks) if they collide ++ if mine_dups: ++ print('Mining possible duplicates...') ++ first_n_xy = traj_poses[:, :fill_first_n, :2, 3] ++ first_n_hvec = traj_poses[:, :fill_first_n, :2, 0] ++ first_n_hvec = first_n_hvec / np.linalg.norm(first_n_hvec, axis=-1, keepdims=True) ++ for ai in range(traj_poses.shape[0]): ++ ai_mask = np.zeros((traj_poses.shape[0]), dtype=bool) ++ ai_mask[ai] = True ++ cur_id = traj_ids[ai] ++ other_ids = traj_ids[~ai_mask] ++ ++ traj_tgt = np.concatenate([first_n_xy[ai], first_n_hvec[ai]], axis=1) ++ lw_tgt = traj_lwh[ai, :2] ++ traj_others = np.concatenate([first_n_xy[~ai_mask], first_n_hvec[~ai_mask]], axis=2) ++ lw_others = traj_lwh[~ai_mask, :2] ++ ++ # if they collide more than 75% of the time ++ veh_coll = check_single_veh_coll(traj_tgt, lw_tgt, traj_others, lw_others) ++ dup_mask = np.sum(veh_coll, axis=1) > int(0.75*veh_coll.shape[1]) ++ ++ if np.sum(dup_mask) > 0: ++ dup_ids = sorted([otid for otid in other_ids[dup_mask].tolist() if otid > cur_id]) ++ dup_ids = [str(otid) for otid in dup_ids] ++ if len(dup_ids) > 0: ++ dup_str = ','.join(dup_ids) ++ dup_str = str(cur_id) + ' : [' + dup_str + '],' ++ print(dup_str) ++ ++ # add ego at index 0 ++ ego_poses = ego_poses[np.newaxis] ++ traj_poses = np.concatenate([ego_poses, traj_poses], axis=0) ++ traj_valid = np.concatenate([np.ones((1, ego_poses.shape[1]), dtype=bool), traj_valid], axis=0) ++ traj_lwh = np.concatenate([np.array([NV_EGO_LWH]), traj_lwh], axis=0) ++ traj_ids = np.array(['ego'] + traj_ids.tolist()) ++ ++ if map_tile is not None: ++ traj_poses[:, :, :2, -1] += map_tile.trans_offset ++ ++ return traj_poses, traj_valid, traj_lwh, ego_pose_timestamps*1e-6, traj_ids ++ ++ ++if __name__ == '__main__': ++ tile_path = './data/nvidia/nvmaps/92d651e5-21d2-4816-b16d-0feace622aa1/jsv3/92d651e5-21d2-4816-b16d-0feace622aa1/tile/4bd02829-cab6-435d-8ebd-679c96787f8b_json' ++ tile = load_tile(tile_path, layers=['lane_dividers_v1', 'lane_channels_v1']) ++ # convert to nuscenes-like map format if desired ++ nusc_tile = convert_tile_to_nuscenes(tile) ++ ++ # dynamic_obj_path = './data/nvidia/dynamic_object_poses.npz' ++ # dyn_obj_data = np.load(dynamic_obj_path) ++ # dyn_obj_poses = dyn_obj_data['poses'] ++ # dyn_obj_timestamps = dyn_obj_data['pose_timestamps'] ++ ++ # print(dyn_obj_poses.shape) ++ ++ # # debug_viz_vid(tile.segments, 'dyn_obj', ++ # # comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ # # poses=dyn_obj_poses[np.newaxis], ++ # # poses_valid=np.ones((1, dyn_obj_poses.shape[0])), ++ # # poses_lwh=np.array([[5.087, 2.307, 1.856]]), ++ # # # pose_ids=traj_ids, ++ # # subsamp=1, ++ # # fps=30) ++ ++ # # TODO option to just load specific ids? ++ ++ # # rivermark ++ # autolabels_path = './data/nvidia/endeavor/labels/autolabels.pkl' ++ # ego_images_path = './data/nvidia/ego_session/processed/44093/images/image_00' ++ # lidar_rig_path = './data/nvidia/endeavor/poses/T_lidar_rig.npz' ++ # frame_range = (510, 880) ++ # traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ # frame_range=frame_range, ++ # postprocess=True, ++ # # extra_obj_path=dynamic_obj_path, ++ # # load_ids=['ego', 'extra'], ++ # crop2valid=False, ++ # fill_first_n=160) ++ ++ # print(traj_poses.shape) ++ # debug_viz_vid(tile.segments, 'rivermark_fill', ++ # comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # pose_ids=traj_ids, ++ # subsamp=1, ++ # fps=30) ++ ++ # exit() ++ ++ ++ autolabels_path = './data/nvidia/endeavor/labels/autolabels.pkl' ++ ego_images_path = './data/nvidia/endeavor/images/image_00' ++ lidar_rig_path = './data/nvidia/endeavor/poses/T_lidar_rig.npz' ++ ++ # this just filters which frames to process, can be None ++ # frame_range = (3170, 3590) # (3200, 3545) # (0, 900), (1000, 2400), None ++ # frame_range = (60, 3590) ++ frame_range = (3000, 3590) ++ # frame_range = (2000, 2400) ++ # frame_range = (570, 870) # merge ++ # frame_range = (240, 660) # merge ++ # frame_range = (1440, 1860) # merge ++ ++ # process on the original nvmap ++ traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ frame_range=frame_range, ++ postprocess=True, ++ mine_dups=False, ++ fill_first_n=300) ++ print(traj_poses.shape) ++ debug_viz_vid(tile.segments, 'extrap_proc_ids_sframe_00003000_eframe_00003300', ++ comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ poses=traj_poses, ++ poses_valid=traj_valid, ++ poses_lwh=traj_lwh, ++ pose_ids=traj_ids, ++ subsamp=1, ++ fps=30) ++ ++ # # process with the nuscenes map version ++ # # (the only difference is the coordinate system is offset such that the origin is at bottom left) ++ # traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ # frame_range=frame_range, ++ # map_tile=nusc_tile, ++ # postprocess=True) ++ # debug_viz_vid(nusc_tile.segments, 'processed_sframe_00003200_eframe_00003545_nusc', ++ # comp_out_path='./out/gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # pose_ids=traj_ids, ++ # subsamp=3, ++ # fps=10) ++ ++ # ++ # output single pose for Amlan ++ # ++ ++ # single_frame_idx = 3500 ++ # last_pose = traj_poses[:,single_frame_idx:single_frame_idx+1] ++ # last_valid = traj_valid[:,single_frame_idx] ++ # last_step_valid_poses = last_pose[last_valid] ++ # print(last_step_valid_poses.shape) ++ # last_step_lwh = traj_lwh[last_valid] ++ ++ # debug_viz_segs(nusc_tile.segments, 'endeavor_nusc_step%d' % (single_frame_idx), poses=last_step_valid_poses, poses_valid=traj_valid[:,single_frame_idx:single_frame_idx+1][last_valid], poses_lwh=last_step_lwh) ++ ++ # np.savez('./out/dev_nvmap/endeavor_poses_frame%06d.npz' % (single_frame_idx), ++ # poses=last_step_valid_poses, ++ # lwh=last_step_lwh) ++ # exit() ++ ++ # ++ # Output GPS trajectories for Jeremy ++ # ++ ++ # # convert ego poses back to global ECEF coordinate system ++ # N, T, _, _ = traj_poses.shape ++ # ecef_traj_poses = np.matmul(GLOBAL_BASE_POSE[np.newaxis,np.newaxis], traj_poses) ++ # print(ecef_traj_poses.shape) ++ # # convert to GPS ++ # gps_traj_poses = ecef_2_lat_lng_alt(ecef_traj_poses.reshape((N*T, 4, 4)), earth_model='WGS84') ++ # lat_lng_alt, orientation_axis, orientation_angle = gps_traj_poses ++ ++ # save_path = os.path.join('./out/dev_nvmap/endeavor_trajectory_track_ids.npz') ++ # out_dict = { ++ # 'timestamps' : traj_t, ++ # 'track_ids' : traj_ids, ++ # 'ecef_poses' : ecef_traj_poses, ++ # 'pose_valid' : traj_valid, ++ # 'gps_lat_lng_alt' : lat_lng_alt.reshape((N, T, 3)), ++ # 'gps_orientation_axis' : orientation_axis.reshape((N, T, 3)), ++ # 'gps_orientation_angle_degrees' : orientation_angle.reshape((N, T, 1)) ++ # } ++ # for k, v in out_dict.items(): ++ # print(k) ++ # print(v.shape) ++ # # if k != 'timestamps': ++ # # print(np.sum(np.isnan(v), axis=1)) ++ # np.savez(save_path, **out_dict) ++ ++ # exit() ++ ++ # debug_viz_vid(tile.segments, 'gt_sframe_00003200_eframe_00003545', ++ # comp_out_path='./out/gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # subsamp=3, ++ # fps=10) +diff --git a/src/trajdata/dataset_specific/drivesim/test.py b/src/trajdata/dataset_specific/drivesim/test.py +new file mode 100644 +index 0000000..ee6b58e +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/test.py +@@ -0,0 +1,79 @@ ++import numpy as np ++from typing import Dict ++import trajdata.dataset_specific.drivesim.nvmap_utils as nvutils ++ ++# This is from the tag from the .xodr file for Endeavor. ++LATLONALT_ORIGIN_ENDEAVOR = np.array([[37.37852062996696, -121.9596180846297, 0.0]]) ++ ++ ++def convert_to_DS(poses: np.ndarray, track_ids:list,fps:float): ++ """_summary_ ++ ++ Args: ++ poses (np.ndarray): x, y, h states of each agent at each time, relative to the ego vehicle's state at time t=0. (N, T, 3) ++ """ ++ # wgs84 = coutils.LocalToWGS84((0, 0, 0), (37.37852062996696, -121.9596180846297, 0.0)) ++ # wgs84_2 = coutils.ECEFtoWGS84(nvutils.GLOBAL_BASE_POSE_ENDEAVOR[:3, -1]) ++ ++ world_from_map_ft = nvutils.lat_lng_alt_2_ecef( ++ LATLONALT_ORIGIN_ENDEAVOR, ++ np.array([[1, 0, 0]]), ++ np.array([[0]]) ++ ) ++ ++ N, T = poses.shape[:2] ++ x = poses[..., 0] ++ y = poses[..., 1] ++ heading = poses[..., 2] ++ ++ c = np.cos(heading) ++ s = np.sin(heading) ++ T_mat = np.tile(np.eye(4), (N, T, 1, 1)) ++ T_mat[..., 0, 0] = c ++ T_mat[..., 0, 1] = -s ++ T_mat[..., 1, 0] = s ++ T_mat[..., 1, 1] = c ++ T_mat[..., 0, 3] = x ++ T_mat[..., 1, 3] = y ++ # TODO: Some height for ray-casting down to road? ++ # T_mat[..., 2, 3] = 0 ++ ++ ecef_traj_poses = np.matmul(world_from_map_ft[:, np.newaxis], T_mat) ++ gps_traj_poses = nvutils.ecef_2_lat_lng_alt(ecef_traj_poses.reshape((N*T, 4, 4)), earth_model='WGS84') ++ lat_lng_alt, orientation_axis, orientation_angle = gps_traj_poses ++ breakpoint() ++ out_dict = { ++ 'timestamps' : np.linspace(0, T/fps, T), ++ 'track_ids' : np.array(track_ids), ++ # 'bbox_lwh' : np.array([[4.387, 1.907, 1.656], [4.387, 1.907, 1.656]]), ++ # 'ecef_poses' : ecef_traj_poses, ++ 'pose_valid' : np.ones((ecef_traj_poses.shape[0], ecef_traj_poses.shape[1]), dtype=bool), ++ 'gps_lat_lng_alt' : lat_lng_alt.reshape((N, T, 3)), ++ 'gps_orientation_axis' : orientation_axis.reshape((N, T, 3)), ++ 'gps_orientation_angle_degrees' : orientation_angle.reshape((N, T, 1)) ++ } ++ ++ return out_dict ++ ++ ++def main(): ++ poses = np.zeros((2, 315, 3)) ++ poses[0, :, 0] = np.linspace(-505, -519, 315) ++ poses[0, :, 1] = np.linspace(-1019, -866, 315) ++ poses[0, :, 2] = np.linspace(np.pi/2, 5*np.pi/8, 315) ++ ++ fps = 30 ++ out_dict = convert_to_DS(poses,["ego","769"],fps) ++ ++ # with np.load("/home/bivanovic/projects/drivesim-ov/source/extensions/omni.drivesim.dl_traffic_model/data/example_trajectories.npz") as data: ++ # out_dict["gps_lat_lng_alt"][1] = data["gps_lat_lng_alt"][1] ++ # out_dict["gps_orientation_axis"][1] = data["gps_orientation_axis"][1] ++ # out_dict["gps_orientation_angle_degrees"][1] = data["gps_orientation_angle_degrees"][1] ++ ++ np.savez( ++ "/home/bivanovic/projects/drivesim-ov/source/extensions/omni.drivesim.dl_traffic_model/data/test.npz", ++ **out_dict ++ ) ++ ++if __name__ == "__main__": ++ main() +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +index 1c4df3f..08f9977 100644 +--- a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py ++++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +@@ -360,38 +360,6 @@ class NuplanDataset(RawDataset): + + return agent_list, agent_presence + +- def cache_map( +- self, +- map_name: str, +- cache_path: Path, +- map_cache_class: Type[SceneCache], +- map_params: Dict[str, Any], +- ) -> None: +- nuplan_map: NuPlanMap = map_factory.get_maps_api( +- map_root=str(self.metadata.data_dir.parent / "maps"), +- map_version=nuplan_utils.NUPLAN_MAP_VERSION, +- map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], +- ) +- +- # Loading all layer geometries. +- nuplan_map.initialize_all_layers() +- +- # This df has the normal lane_connectors with additional boundary information, +- # which we want to use, however the default index is not the lane_connector_fid, +- # although it is a 1:1 mapping so we instead create another index with the +- # lane_connector_fids as the key and the resulting integer indices as the value. +- lane_connector_fids: pd.Series = nuplan_map._vector_map[ +- "gen_lane_connectors_scaled_width_polygons" +- ]["lane_connector_fid"] +- lane_connector_idxs: pd.Series = pd.Series( +- index=lane_connector_fids, data=range(len(lane_connector_fids)) +- ) +- +- vector_map = VectorMap(map_id=f"{self.name}:{map_name}") +- nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) +- +- map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) +- + def cache_maps( + self, + cache_path: Path, +@@ -406,4 +374,45 @@ class NuplanDataset(RawDataset): + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): +- self.cache_map(map_name, cache_path, map_cache_class, map_params) ++ cache_map( ++ map_root=str(self.metadata.data_dir.parent / "maps"), ++ env_name=self.name, ++ map_name=map_name, ++ cache_path=cache_path, ++ map_cache_class=map_cache_class, ++ map_params=map_params, ++ ) ++ ++ ++def cache_map( ++ map_root: str, ++ env_name: str, ++ map_name: str, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++) -> None: ++ nuplan_map: NuPlanMap = map_factory.get_maps_api( ++ map_root=map_root, ++ map_version=nuplan_utils.NUPLAN_MAP_VERSION, ++ map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], ++ ) ++ ++ # Loading all layer geometries. ++ nuplan_map.initialize_all_layers() ++ ++ # This df has the normal lane_connectors with additional boundary information, ++ # which we want to use, however the default index is not the lane_connector_fid, ++ # although it is a 1:1 mapping so we instead create another index with the ++ # lane_connector_fids as the key and the resulting integer indices as the value. ++ lane_connector_fids: pd.Series = nuplan_map._vector_map[ ++ "gen_lane_connectors_scaled_width_polygons" ++ ]["lane_connector_fid"] ++ lane_connector_idxs: pd.Series = pd.Series( ++ index=lane_connector_fids, data=range(len(lane_connector_fids)) ++ ) ++ ++ vector_map = VectorMap(map_id=f"{env_name}:{map_name}") ++ nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) ++ ++ map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py +index 5007ace..dea194c 100644 +--- a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py ++++ b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py +@@ -189,6 +189,7 @@ class NuPlanObject: + + + def nuplan_type_to_unified_type(nuplan_type: str) -> AgentType: ++ # TODO (pkarkus) map traffic cones, barriers to static; generic_object to pedestrian + if nuplan_type == "pedestrian": + return AgentType.PEDESTRIAN + elif nuplan_type == "bicycle": +@@ -327,12 +328,20 @@ def populate_vector_map( + # The right boundary of Lane A has Lane A to its left. + boundary_connectivity_dict[right_boundary_id]["left"].append(lane_id) + ++ # Find road areas that this lane intersects for faster lane-based lookup later. ++ intersect_filt = nuplan_map._vector_map["drivable_area"].intersects(lane_info["geometry"]) ++ isnear_filt = (nuplan_map._vector_map["drivable_area"].distance(lane_info["geometry"]) < 3.) ++ road_area_ids = set(nuplan_map._vector_map["drivable_area"][intersect_filt | isnear_filt]["fid"].values) ++ if not road_area_ids: ++ print (f"Warning: no road lane associated with lane {lane_id}") ++ + # "partial" because we aren't adding lane connectivity until later. + partial_new_lane = RoadLane( + id=lane_id, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), ++ road_area_ids=road_area_ids, + ) + vector_map.add_map_element(partial_new_lane) + overall_pbar.update() +diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py +index 68bd5b9..0c6f29b 100644 +--- a/src/trajdata/dataset_specific/scene_records.py ++++ b/src/trajdata/dataset_specific/scene_records.py +@@ -28,6 +28,12 @@ class NuscSceneRecord(NamedTuple): + desc: str + data_idx: int + ++class CarlaSceneRecord(NamedTuple): ++ name: str ++ location: str ++ length: str ++ data_idx: int ++ + + class LyftSceneRecord(NamedTuple): + name: str +@@ -48,3 +54,11 @@ class NuPlanSceneRecord(NamedTuple): + split: str + # desc: str + data_idx: int ++ ++class DrivesimSceneRecord(NamedTuple): ++ name: str ++ location: str ++ length: str ++ split: str ++ # desc: str ++ data_idx: int +diff --git a/src/trajdata/maps/map_api.py b/src/trajdata/maps/map_api.py +index 36a1245..e6e4a6d 100644 +--- a/src/trajdata/maps/map_api.py ++++ b/src/trajdata/maps/map_api.py +@@ -1,6 +1,6 @@ + from __future__ import annotations + +-from typing import TYPE_CHECKING, Optional ++from typing import TYPE_CHECKING, Optional, Union + + if TYPE_CHECKING: + from trajdata.maps.map_kdtree import MapElementKDTree +@@ -15,9 +15,20 @@ from trajdata.utils import map_utils + + + class MapAPI: +- def __init__(self, unified_cache_path: Path) -> None: ++ def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> None: ++ """A simple interface for loading trajdata's vector maps which does not require ++ instantiation of a `UnifiedDataset` object. ++ ++ Args: ++ unified_cache_path (Path): Path to trajdata's local cache on disk. ++ keep_in_memory (bool): Whether loaded maps should be stored ++ in memory (memoized) for later re-use. For most cases (e.g., batched dataloading), ++ this is a good idea. However, this can cause rapid memory usage growth for some ++ datasets (e.g., Waymo) and it can be better to disable this. Defaults to False. ++ """ + self.unified_cache_path: Path = unified_cache_path + self.maps: Dict[str, VectorMap] = dict() ++ self._keep_in_memory = keep_in_memory + + def get_map( + self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs +@@ -25,22 +36,43 @@ class MapAPI: + if map_id not in self.maps: + env_name, map_name = map_id.split(":") + env_maps_path: Path = self.unified_cache_path / env_name / "maps" +- stored_vec_map: VectorizedMap = map_utils.load_vector_map( +- env_maps_path / f"{map_name}.pb" +- ) ++ vec_map_path: Path = env_maps_path / f"{map_name}.pb" ++ ++ if not Path.exists(vec_map_path): ++ if self.data_dirs is None: ++ raise ValueError( ++ f"There is no cached map at {vec_map_path} and there was no " + ++ "`data_dirs` provided to rebuild cache.") ++ ++ # Rebuild maps by creating a dummy dataset object. ++ # TODO(pkarkus) We need support for rebuilding map files only, without creating dataset and building agent data. ++ from trajdata.dataset import UnifiedDataset ++ dataset = UnifiedDataset( ++ desired_data=[env_name], ++ rebuild_maps=True, ++ data_dirs=self.data_dirs, ++ cache_location=self.unified_cache_path, ++ verbose=True, ++ ) ++ # Hopefully we successfully created map cache. ++ ++ stored_vec_map: VectorizedMap = map_utils.load_vector_map(vec_map_path) + + vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs) + vec_map.search_kdtrees: Dict[ + str, MapElementKDTree + ] = map_utils.load_kdtrees(env_maps_path / f"{map_name}_kdtrees.dill") + +- self.maps[map_id] = vec_map ++ if self._keep_in_memory: ++ self.maps[map_id] = vec_map ++ else: ++ vec_map = self.maps[map_id] + + if scene_cache is not None: +- self.maps[map_id].associate_scene_data( ++ vec_map.associate_scene_data( + scene_cache.get_traffic_light_status_dict( + kwargs.get("desired_dt", None) + ) + ) + +- return self.maps[map_id] ++ return vec_map +diff --git a/src/trajdata/maps/map_kdtree.py b/src/trajdata/maps/map_kdtree.py +index 59abdc2..a978a90 100644 +--- a/src/trajdata/maps/map_kdtree.py ++++ b/src/trajdata/maps/map_kdtree.py +@@ -1,7 +1,7 @@ + from __future__ import annotations + + from collections import defaultdict +-from typing import TYPE_CHECKING ++from typing import TYPE_CHECKING, Iterator + + if TYPE_CHECKING: + from trajdata.maps.vec_map import VectorMap +@@ -43,8 +43,7 @@ class MapElementKDTree: + total=len(vector_map), + disable=not verbose, + ): +- result = self._extract_points(map_elem) +- if result is not None: ++ for result in self._extract_points_and_metadata(map_elem): + points, extras = result + polyline_inds.extend([len(polylines)] * points.shape[0]) + +@@ -54,6 +53,7 @@ class MapElementKDTree: + + for k, v in extras.items(): + metadata[k].append(v) ++ metadata["map_elem_id"].append(np.array([map_elem.id])) + + points = np.concatenate(polylines, axis=0) + polyline_inds = np.array(polyline_inds) +@@ -64,7 +64,7 @@ class MapElementKDTree: + + def _extract_points_and_metadata( + self, map_element: MapElement +- ) -> Optional[Tuple[np.ndarray, dict[str, np.ndarray]]]: ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + """Defines the coordinates we want to store in the KDTree for a MapElement. + Args: + map_element (MapElement): the MapElement to store in the KDTree. +@@ -116,16 +116,16 @@ class LaneCenterKDTree(MapElementKDTree): + self.max_segment_len = max_segment_len + super().__init__(vector_map) + +- def _extract_points(self, map_element: MapElement) -> Optional[np.ndarray]: ++ def _extract_points_and_metadata( ++ self, map_element: MapElement ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + if map_element.elem_type == MapElementType.ROAD_LANE: + pts: Polyline = map_element.center + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + + # We only want to store xyz in the kdtree, not heading. +- return pts.xyz, {"heading": pts.h} +- else: +- return None ++ yield pts.xyz, {"heading": pts.h} + + def current_lane_inds( + self, +@@ -181,3 +181,42 @@ class LaneCenterKDTree(MapElementKDTree): + min_costs = [np.min(costs[lane_inds == ind]) for ind in unique_lane_inds] + + return unique_lane_inds[np.argsort(min_costs)] ++ ++ ++class RoadAreaKDTree(MapElementKDTree): ++ """KDTree for road area polygons. ++ The polygons may have holes. We will simply store points along both the ++ exterior_polygon and all interior_holes. Finding a nearest point in this KDTree will ++ correspond to finding any ++ """ ++ ++ def __init__( ++ self, vector_map: VectorMap, max_segment_len: Optional[float] = None ++ ) -> None: ++ """ ++ Args: ++ vec_map: the VectorizedMap object to build the KDTree for ++ max_segment_len (float, optional): if specified, we will insert extra points into the KDTree ++ such that all polyline segments are shorter then max_segment_len. ++ """ ++ self.max_segment_len = max_segment_len ++ super().__init__(vector_map) ++ ++ def _extract_points_and_metadata( ++ self, map_element: MapElement ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: ++ if map_element.elem_type == MapElementType.ROAD_AREA: ++ # Exterior polygon ++ pts: Polyline = map_element.exterior_polygon ++ if self.max_segment_len is not None: ++ pts = pts.interpolate(max_dist=self.max_segment_len) ++ # We only want to store xyz in the kdtree, not heading. ++ yield pts.xyz, {"exterior": np.array([True])} ++ ++ # Interior holes ++ for pts in map_element.interior_holes: ++ if self.max_segment_len is not None: ++ pts = pts.interpolate(max_dist=self.max_segment_len) ++ # We only want to store xyz in the kdtree, not heading. ++ yield pts.xyz, {"exterior": np.array([False])} ++ +diff --git a/src/trajdata/maps/vec_map.py b/src/trajdata/maps/vec_map.py +index 95b299d..52a225f 100644 +--- a/src/trajdata/maps/vec_map.py ++++ b/src/trajdata/maps/vec_map.py +@@ -25,9 +25,10 @@ import matplotlib.pyplot as plt + import numpy as np + from matplotlib.axes import Axes + from tqdm import tqdm ++from shapely.geometry import Polygon + + import trajdata.proto.vectorized_map_pb2 as map_proto +-from trajdata.maps.map_kdtree import LaneCenterKDTree ++from trajdata.maps.map_kdtree import LaneCenterKDTree, RoadAreaKDTree + from trajdata.maps.traffic_light_status import TrafficLightStatus + from trajdata.maps.vec_map_elements import ( + MapElement, +@@ -52,6 +53,7 @@ class VectorMap: + ) + search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None + traffic_light_status: Optional[Dict[Tuple[int, int], TrafficLightStatus]] = None ++ online_metadict: Optional[Dict[Tuple[str, int], Dict]] = None + + def __post_init__(self) -> None: + self.env_name, self.map_name = self.map_id.split(":") +@@ -60,11 +62,16 @@ class VectorMap: + if MapElementType.ROAD_LANE in self.elements: + self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) + ++ # self._road_area_polygons: Dict[str, Polygon] = {} ++ + def add_map_element(self, map_elem: MapElement) -> None: + self.elements[map_elem.elem_type][map_elem.id] = map_elem + + def compute_search_indices(self) -> None: +- self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} ++ self.search_kdtrees = { ++ MapElementType.ROAD_LANE: LaneCenterKDTree(self), ++ # MapElementType.ROAD_AREA: RoadAreaKDTree(self), ++ } + + def iter_elems(self) -> Iterator[MapElement]: + for elems_dict in self.elements.values(): +@@ -101,6 +108,9 @@ class VectorMap: + new_lane.adjacent_lanes_right.extend( + [lane_id.encode() for lane_id in road_lane.adj_lanes_right] + ) ++ # new_lane.road_area_ids.extend( ++ # [road_area_id.encode() for road_area_id in road_lane.road_area_ids] ++ # ) + + def _write_road_areas( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray +@@ -250,6 +260,9 @@ class VectorMap: + prev_lanes: Set[str] = set( + [iden.decode() for iden in road_lane_obj.entry_lanes] + ) ++ # road_area_ids: Set[str] = set( ++ # [iden.decode() for iden in road_lane_obj.road_area_ids] ++ # ) + + # Double-using the connectivity attributes for lane IDs now (will + # replace them with Lane objects after all Lane objects have been created). +@@ -262,6 +275,7 @@ class VectorMap: + adj_lanes_right, + next_lanes, + prev_lanes, ++ # road_area_ids=road_area_ids, + ) + map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane + +@@ -369,6 +383,33 @@ class VectorMap: + self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyz, dist) + ] + ++ def get_road_areas_within(self, xyz: np.ndarray, dist: float) -> List[RoadArea]: ++ road_area_kdtree: RoadAreaKDTree = self.search_kdtrees[MapElementType.ROAD_AREA] ++ polyline_inds = road_area_kdtree.polyline_inds_in_range(xyz, dist) ++ element_ids = set([ ++ road_area_kdtree.metadata["map_elem_id"][ind] for ind in polyline_inds ++ ]) ++ if MapElementType.ROAD_AREA not in self.elements: ++ raise ValueError( ++ "Road areas are not loaded. Use map_api.get_map(..., incl_road_areas=True)." ++ ) ++ return [ ++ self.elements[MapElementType.ROAD_AREA][id] for id in element_ids ++ ] ++ ++ def get_road_area_polygon_2d(self, id: str) -> Polygon: ++ if id not in self._road_area_polygons: ++ road_area: RoadArea = self.elements[MapElementType.ROAD_AREA][id] ++ road_area_polygon = Polygon( ++ shell=[(pt[0], pt[1]) for pt in road_area.exterior_polygon.points], ++ holes=[ ++ [(pt[0], pt[1]) for pt in polyline.points] ++ for polyline in road_area.interior_holes ++ ] ++ ) ++ self._road_area_polygons[id] = road_area_polygon ++ return self._road_area_polygons[id] ++ + def get_traffic_light_status( + self, lane_id: str, scene_ts: int + ) -> TrafficLightStatus: +@@ -380,6 +421,15 @@ class VectorMap: + else TrafficLightStatus.NO_DATA + ) + ++ def get_online_metadict( ++ self, lane_id: str, scene_ts: int = 0 ++ ) -> Dict: ++ return ( ++ self.online_metadict[(str(lane_id), scene_ts)] ++ if self.online_metadict is not None ++ else {} ++ ) ++ + def rasterize( + self, resolution: float = 2, **kwargs + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: +diff --git a/src/trajdata/maps/vec_map_elements.py b/src/trajdata/maps/vec_map_elements.py +index fd7a9f1..d1e34b6 100644 +--- a/src/trajdata/maps/vec_map_elements.py ++++ b/src/trajdata/maps/vec_map_elements.py +@@ -1,10 +1,11 @@ + from dataclasses import dataclass, field + from enum import IntEnum +-from typing import List, Optional, Set ++from typing import List, Optional, Set, Union + + import numpy as np + + from trajdata.utils import map_utils ++from trajdata.utils.arr_utils import angle_wrap + + + class MapElementType(IntEnum): +@@ -47,6 +48,15 @@ class Polyline: + def xyz(self) -> np.ndarray: + return self.points[..., :3] + ++ @property ++ def xyh(self) -> np.ndarray: ++ if self.has_heading: ++ return self.points[..., (0, 1, 3)] ++ else: ++ raise ValueError( ++ f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4." ++ ) ++ + @property + def xyzh(self) -> np.ndarray: + if self.has_heading: +@@ -67,14 +77,17 @@ class Polyline: + map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist) + ) + +- def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: ++ def project_onto(self, xyz_or_xyzh: np.ndarray, return_index: bool = False) -> Union[np.ndarray, List]: + """Project the given points onto this Polyline. + + Args: + xyzh (np.ndarray): Points to project, of shape (M, D) ++ return_indices (bool): Return the index of starting point of the line segment ++ on which the projected points lies on. + + Returns: + np.ndarray: The projected points, of shape (M, D) ++ np.ndarray: The index of previous polyline points if return_indices == True. + + Note: + D = 4 if this Polyline has headings, otherwise D = 3 +@@ -94,7 +107,8 @@ class Polyline: + dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum( + axis=-1, keepdims=True + ) +- norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 ++ # norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 ++ norms: np.ndarray = np.square(line_seg_diffs).sum(axis=-1, keepdims=True) + + # Clip ensures that the projected point stays within the line segment boundaries. + projs: np.ndarray = ( +@@ -102,20 +116,114 @@ class Polyline: + ) + + # 2. Find the nearest projections to the original points. +- closest_proj_idxs: int = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1) ++ # We have nan values when two consecutive points are equal. This will never be ++ # the closest projection point, so we replace nans with a large number. ++ point_to_proj_dist = np.nan_to_num(np.linalg.norm(xyz - projs, axis=-1), nan=1e6) ++ closest_proj_idxs: int = point_to_proj_dist.argmin(axis=-1) ++ ++ proj_points = projs[range(xyz.shape[0]), closest_proj_idxs] + + if self.has_heading: + # Adding in the heading of the corresponding p0 point (which makes + # sense as p0 to p1 is a line => same heading along it). +- return np.concatenate( ++ proj_points = np.concatenate( + [ +- projs[range(xyz.shape[0]), closest_proj_idxs], ++ proj_points, + np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), + ], + axis=-1, + ) ++ ++ if return_index: ++ return proj_points, closest_proj_idxs + else: +- return projs[range(xyz.shape[0]), closest_proj_idxs] ++ return proj_points ++ ++ def distance_to_point(self, xyz: np.ndarray): ++ assert xyz.ndim == 2 ++ xyz_proj = self.project_onto(xyz) ++ return np.linalg.norm(xyz[..., :3] - xyz_proj[..., :3], axis=-1) ++ ++ def get_length(self): ++ # TODO(pkarkus) we could store cummulative distances to speed this up ++ dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) ++ length = dists.sum() ++ return length ++ ++ ++ def get_length_from(self, start_ind: np.ndarray): ++ # TODO(pkarkus) we could store cummulative distances to speed this up ++ assert start_ind.ndim == 1 ++ dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) ++ length_upto = np.cumsum(np.pad(dists, (1, 0))) ++ length_from = length_upto[-1][None] - length_upto[start_ind] ++ return length_from ++ ++ ++ def traverse_along(self, dist: np.ndarray, start_ind: Optional[np.ndarray] = None) -> np.ndarray: ++ """ ++ Interpolated endpoint of traversing `dist` distance along polyline from a starting point. ++ ++ Returns nan if the end point is not inside the polyline. ++ TODO(pkarkus) we could store cummulative distances to speed this up ++ ++ Args: ++ dist (np.ndarray): distances, any shape [...] ++ start_ind (np.ndarray): index of point along polyline to calcualte distance from. ++ Optional. Shape must match dist. [...] ++ ++ Returns: ++ endpoint_xyzh (np.ndarray): points along polyline `dist` distance from the ++ starting point. Nan if endpoint would require extrapolation. [..., 4] ++ ++ """ ++ assert self.has_heading ++ ++ # Add up distances from beginning of polyline ++ segment_lens = np.linalg.norm(self.xyz[1:] - self.xyz[:-1], axis=-1) # n-1 ++ cum_len = np.pad(np.cumsum(segment_lens, axis=0), (1, 0)) # n ++ ++ # Increase dist with the length of lane up to start_ind ++ if start_ind is not None: ++ assert start_ind.ndim == dist.ndim ++ dist = dist + cum_len[start_ind] ++ ++ # Find the first index where cummulative length is larger or equal than `dist` ++ inds = np.searchsorted(cum_len, dist, side='right') ++ # Invalidate inds == 0 and inds == len(cum_len), which means endpoint is outside the polyline. ++ invalid = np.logical_or(inds == 0, inds == len(cum_len)) ++ # Replace invalid indices so we can easily carry out computation below, and invalidate output eventually. ++ inds[invalid] = 1 ++ ++ # Remaining distance from last point ++ remaining_dist = dist - cum_len[inds-1] ++ ++ # Invalidate negative remaining dist (this should only happen when dist < 0) ++ invalid = np.logical_or(invalid, remaining_dist < 0.) ++ ++ # Interpolate between the previous and next points. ++ segment_vect_xyz = self.xyz[inds] - self.xyz[inds-1] ++ segment_len = np.linalg.norm(segment_vect_xyz, axis=-1) ++ assert (segment_len > 0.).all(), "Polyline segment has zero length" ++ ++ proportion = (remaining_dist / segment_len) ++ endpoint_xyz = segment_vect_xyz * proportion[..., np.newaxis] + self.xyz[inds] ++ endpoint_h = angle_wrap(angle_wrap(self.h[inds] - self.h[inds-1]) * proportion + self.h[inds-1]) ++ endpoint_xyzh = np.concatenate((endpoint_xyz, endpoint_h[..., np.newaxis]), axis=-1) ++ ++ # Invalidate dummy output ++ endpoint_xyzh[invalid] = np.nan ++ ++ return endpoint_xyzh ++ ++ def concatenate_with(self, other: "Polyline") -> "Polyline": ++ return self.concatenate([self, other]) ++ ++ @staticmethod ++ def concatenate(polylines: List["Polyline"]) -> "Polyline": ++ # Assumes no overlap between consecutive polylines, i.e. next lane starts after current lane ends. ++ points = np.concatenate([polyline.points for polyline in polylines], axis=0) ++ return Polyline(points) + + + @dataclass +@@ -132,6 +240,7 @@ class RoadLane(MapElement): + adj_lanes_right: Set[str] = field(default_factory=lambda: set()) + next_lanes: Set[str] = field(default_factory=lambda: set()) + prev_lanes: Set[str] = field(default_factory=lambda: set()) ++ road_area_ids: Set[str] = field(default_factory=lambda: set()) + elem_type: MapElementType = MapElementType.ROAD_LANE + + def __post_init__(self) -> None: +@@ -150,7 +259,35 @@ class RoadLane(MapElement): + @property + def reachable_lanes(self) -> Set[str]: + return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes +- ++ ++ ++ def combine_next(self, next_lane): ++ assert next_lane.id in self.next_lanes ++ self.next_lanes.remove(next_lane.id) ++ self.next_lanes = self.next_lanes.union(next_lane.next_lanes) ++ self.center = self.center.concatenate_with(next_lane.center) ++ if self.left_edge is not None and next_lane.left_edge is not None: ++ self.left_edge = self.left_edge.concatenate_with(next_lane.left_edge) ++ if self.right_edge is not None and next_lane.right_edge is not None: ++ self.right_edge = self.right_edge.concatenate_with(next_lane.right_edge) ++ self.adj_lanes_right = self.adj_lanes_right.union(next_lane.adj_lanes_right) ++ self.adj_lanes_left = self.adj_lanes_left.union(next_lane.adj_lanes_left) ++ self.road_area_ids = self.road_area_ids.union(next_lane.road_area_ids) ++ ++ def combine_prev(self,prev_lane): ++ assert prev_lane.id in self.prev_lanes ++ self.prev_lanes.remove(prev_lane.id) ++ self.prev_lanes = self.prev_lanes.union(prev_lane.prev_lanes) ++ self.center = prev_lane.center.concatenate_with(self.center) ++ if self.left_edge is not None and prev_lane.left_edge is not None: ++ self.left_edge = prev_lane.left_edge.concatenate_with(self.left_edge) ++ if self.right_edge is not None and prev_lane.right_edge is not None: ++ self.right_edge = prev_lane.right_edge.concatenate_with(self.right_edge) ++ self.adj_lanes_right = self.adj_lanes_right.union(prev_lane.adj_lanes_right) ++ self.adj_lanes_left = self.adj_lanes_left.union(prev_lane.adj_lanes_left) ++ self.road_area_ids = self.road_area_ids.union(prev_lane.road_area_ids) ++ ++ + + @dataclass + class RoadArea(MapElement): +diff --git a/src/trajdata/maps/vec_map_remote.py b/src/trajdata/maps/vec_map_remote.py +new file mode 100644 +index 0000000..77f2375 +--- /dev/null ++++ b/src/trajdata/maps/vec_map_remote.py +@@ -0,0 +1,582 @@ ++from __future__ import annotations ++ ++from typing import TYPE_CHECKING ++ ++if TYPE_CHECKING: ++ from trajdata.maps.map_kdtree import MapElementKDTree, LaneCenterKDTree ++ ++from collections import defaultdict ++from dataclasses import dataclass, field ++from math import ceil ++from typing import ( ++ DefaultDict, ++ Dict, ++ Iterator, ++ List, ++ Optional, ++ Set, ++ Tuple, ++ Union, ++ overload, ++) ++ ++import matplotlib as mpl ++import matplotlib.pyplot as plt ++import numpy as np ++from matplotlib.axes import Axes ++from tqdm import tqdm ++ ++import trajdata.proto.vectorized_map_pb2 as map_proto ++from trajdata.maps.map_kdtree import LaneCenterKDTree ++from trajdata.maps.traffic_light_status import TrafficLightStatus ++from trajdata.maps.vec_map_elements import ( ++ MapElement, ++ MapElementType, ++ PedCrosswalk, ++ PedWalkway, ++ Polyline, ++ RoadArea, ++ RoadLane, ++) ++from trajdata.utils import map_utils, raster_utils ++ ++ ++@dataclass(repr=False) ++class VectorMap: ++ map_id: str ++ extent: Optional[ ++ np.ndarray ++ ] = None # extent is [min_x, min_y, min_z, max_x, max_y, max_z] ++ elements: DefaultDict[MapElementType, Dict[str, MapElement]] = field( ++ default_factory=lambda: defaultdict(dict) ++ ) ++ search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None ++ traffic_light_status: Optional[Dict[Tuple[int, int], TrafficLightStatus]] = None ++ ++ def __post_init__(self) -> None: ++ self.env_name, self.map_name = self.map_id.split(":") ++ ++ self.lanes: Optional[List[RoadLane]] = None ++ if MapElementType.ROAD_LANE in self.elements: ++ self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) ++ ++ def add_map_element(self, map_elem: MapElement) -> None: ++ self.elements[map_elem.elem_type][map_elem.id] = map_elem ++ ++ def compute_search_indices(self) -> None: ++ self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} ++ ++ def iter_elems(self) -> Iterator[MapElement]: ++ for elems_dict in self.elements.values(): ++ for elem in elems_dict.values(): ++ yield elem ++ ++ def get_road_lane(self, lane_id: str) -> RoadLane: ++ return self.elements[MapElementType.ROAD_LANE][lane_id] ++ ++ def __len__(self) -> int: ++ return sum(len(elems_dict) for elems_dict in self.elements.values()) ++ ++ def _write_road_lanes( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ road_lane: RoadLane ++ for elem_id, road_lane in self.elements[MapElementType.ROAD_LANE].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_lane: map_proto.RoadLane = new_element.road_lane ++ map_utils.populate_lane_polylines(new_lane, road_lane, shifted_origin) ++ ++ new_lane.entry_lanes.extend( ++ [lane_id.encode() for lane_id in road_lane.prev_lanes] ++ ) ++ new_lane.exit_lanes.extend( ++ [lane_id.encode() for lane_id in road_lane.next_lanes] ++ ) ++ ++ new_lane.adjacent_lanes_left.extend( ++ [lane_id.encode() for lane_id in road_lane.adj_lanes_left] ++ ) ++ new_lane.adjacent_lanes_right.extend( ++ [lane_id.encode() for lane_id in road_lane.adj_lanes_right] ++ ) ++ ++ def _write_road_areas( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ road_area: RoadArea ++ for elem_id, road_area in self.elements[MapElementType.ROAD_AREA].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_area: map_proto.RoadArea = new_element.road_area ++ map_utils.populate_polygon( ++ new_area.exterior_polygon, ++ road_area.exterior_polygon.xyz, ++ shifted_origin, ++ ) ++ ++ hole: Polyline ++ for hole in road_area.interior_holes: ++ new_hole: map_proto.Polyline = new_area.interior_holes.add() ++ map_utils.populate_polygon( ++ new_hole, ++ hole.xyz, ++ shifted_origin, ++ ) ++ ++ def _write_ped_crosswalks( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ ped_crosswalk: PedCrosswalk ++ for elem_id, ped_crosswalk in self.elements[ ++ MapElementType.PED_CROSSWALK ++ ].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_crosswalk: map_proto.PedCrosswalk = new_element.ped_crosswalk ++ map_utils.populate_polygon( ++ new_crosswalk.polygon, ++ ped_crosswalk.polygon.xyz, ++ shifted_origin, ++ ) ++ ++ def _write_ped_walkways( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ ped_walkway: PedWalkway ++ for elem_id, ped_walkway in self.elements[MapElementType.PED_WALKWAY].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_walkway: map_proto.PedWalkway = new_element.ped_walkway ++ map_utils.populate_polygon( ++ new_walkway.polygon, ++ ped_walkway.polygon.xyz, ++ shifted_origin, ++ ) ++ ++ def to_proto(self) -> map_proto.VectorizedMap: ++ output_map = map_proto.VectorizedMap() ++ output_map.name = self.map_id ++ ++ ( ++ output_map.min_pt.x, ++ output_map.min_pt.y, ++ output_map.min_pt.z, ++ output_map.max_pt.x, ++ output_map.max_pt.y, ++ output_map.max_pt.z, ++ ) = self.extent ++ ++ shifted_origin: np.ndarray = self.extent[:3] ++ ( ++ output_map.shifted_origin.x, ++ output_map.shifted_origin.y, ++ output_map.shifted_origin.z, ++ ) = shifted_origin ++ ++ # Populating the elements in the vectorized map protobuf. ++ self._write_road_lanes(output_map, shifted_origin) ++ self._write_road_areas(output_map, shifted_origin) ++ self._write_ped_crosswalks(output_map, shifted_origin) ++ self._write_ped_walkways(output_map, shifted_origin) ++ ++ return output_map ++ ++ @classmethod ++ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): ++ # Options for which map elements to include. ++ incl_road_lanes: bool = kwargs.get("incl_road_lanes", True) ++ incl_road_areas: bool = kwargs.get("incl_road_areas", False) ++ incl_ped_crosswalks: bool = kwargs.get("incl_ped_crosswalks", False) ++ incl_ped_walkways: bool = kwargs.get("incl_ped_walkways", False) ++ ++ # Add any map offset in case the map origin was shifted for storage efficiency. ++ shifted_origin: np.ndarray = np.array( ++ [ ++ vec_map.shifted_origin.x, ++ vec_map.shifted_origin.y, ++ vec_map.shifted_origin.z, ++ 0.0, # Some polylines also have heading so we're adding ++ # this (zero) coordinate to account for that. ++ ] ++ ) ++ ++ map_elem_dict: Dict[str, Dict[str, MapElement]] = defaultdict(dict) ++ ++ map_elem: MapElement ++ for map_elem in vec_map.elements: ++ elem_id: str = map_elem.id.decode() ++ if incl_road_lanes and map_elem.HasField("road_lane"): ++ road_lane_obj: map_proto.RoadLane = map_elem.road_lane ++ ++ center_pl: Polyline = Polyline( ++ map_utils.proto_to_np(road_lane_obj.center) + shifted_origin ++ ) ++ ++ # We do not care for the heading of the left and right edges ++ # (only the center matters). ++ left_pl: Optional[Polyline] = None ++ if road_lane_obj.HasField("left_boundary"): ++ left_pl = Polyline( ++ map_utils.proto_to_np( ++ road_lane_obj.left_boundary, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ right_pl: Optional[Polyline] = None ++ if road_lane_obj.HasField("right_boundary"): ++ right_pl = Polyline( ++ map_utils.proto_to_np( ++ road_lane_obj.right_boundary, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ adj_lanes_left: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.adjacent_lanes_left] ++ ) ++ adj_lanes_right: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.adjacent_lanes_right] ++ ) ++ ++ next_lanes: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.exit_lanes] ++ ) ++ prev_lanes: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.entry_lanes] ++ ) ++ ++ # Double-using the connectivity attributes for lane IDs now (will ++ # replace them with Lane objects after all Lane objects have been created). ++ curr_lane = RoadLane( ++ elem_id, ++ center_pl, ++ left_pl, ++ right_pl, ++ adj_lanes_left, ++ adj_lanes_right, ++ next_lanes, ++ prev_lanes, ++ ) ++ map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane ++ ++ elif incl_road_areas and map_elem.HasField("road_area"): ++ road_area_obj: map_proto.RoadArea = map_elem.road_area ++ ++ exterior: Polyline = Polyline( ++ map_utils.proto_to_np( ++ road_area_obj.exterior_polygon, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ interior_holes: List[Polyline] = list() ++ interior_hole: map_proto.Polyline ++ for interior_hole in road_area_obj.interior_holes: ++ interior_holes.append( ++ Polyline( ++ map_utils.proto_to_np(interior_hole, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ) ++ ++ curr_area = RoadArea(elem_id, exterior, interior_holes) ++ map_elem_dict[MapElementType.ROAD_AREA][elem_id] = curr_area ++ ++ elif incl_ped_crosswalks and map_elem.HasField("ped_crosswalk"): ++ ped_crosswalk_obj: map_proto.PedCrosswalk = map_elem.ped_crosswalk ++ ++ polygon_vertices: Polyline = Polyline( ++ map_utils.proto_to_np(ped_crosswalk_obj.polygon, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ++ curr_area = PedCrosswalk(elem_id, polygon_vertices) ++ map_elem_dict[MapElementType.PED_CROSSWALK][elem_id] = curr_area ++ ++ elif incl_ped_walkways and map_elem.HasField("ped_walkway"): ++ ped_walkway_obj: map_proto.PedCrosswalk = map_elem.ped_walkway ++ ++ polygon_vertices: Polyline = Polyline( ++ map_utils.proto_to_np(ped_walkway_obj.polygon, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ++ curr_area = PedWalkway(elem_id, polygon_vertices) ++ map_elem_dict[MapElementType.PED_WALKWAY][elem_id] = curr_area ++ ++ return cls( ++ map_id=vec_map.name, ++ extent=np.array( ++ [ ++ vec_map.min_pt.x, ++ vec_map.min_pt.y, ++ vec_map.min_pt.z, ++ vec_map.max_pt.x, ++ vec_map.max_pt.y, ++ vec_map.max_pt.z, ++ ] ++ ), ++ elements=map_elem_dict, ++ search_kdtrees=None, ++ traffic_light_status=None, ++ ) ++ ++ def associate_scene_data( ++ self, traffic_light_status_dict: Dict[Tuple[int, int], TrafficLightStatus] ++ ) -> None: ++ """Associates vector map with scene-specific data like traffic light information""" ++ self.traffic_light_status = traffic_light_status_dict ++ ++ def get_current_lane( ++ self, ++ xyzh: np.ndarray, ++ max_dist: float = 2.0, ++ max_heading_error: float = np.pi / 8, ++ ) -> List[RoadLane]: ++ """ ++ Args: ++ xyzh (np.ndarray): 3d position and heading of agent in world coordinates ++ ++ Returns: ++ List[RoadLane]: List of possible road lanes that agent could be on ++ """ ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return [ ++ self.lanes[idx] ++ for idx in lane_kdtree.current_lane_inds(xyzh, max_dist, max_heading_error) ++ ] ++ ++ def get_closest_lane(self, xyz: np.ndarray) -> RoadLane: ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return self.lanes[lane_kdtree.closest_polyline_ind(xyz)] ++ ++ def get_closest_unique_lanes(self, xyz_vec: np.ndarray) -> List[RoadLane]: ++ assert xyz_vec.ndim == 2 # xyz_vec is assumed to be (*, 3) ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ closest_inds = lane_kdtree.closest_polyline_ind(xyz_vec) ++ unique_inds = np.unique(closest_inds) ++ return [self.lanes[ind] for ind in unique_inds] ++ ++ def get_lanes_within(self, xyz: np.ndarray, dist: float) -> List[RoadLane]: ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return [ ++ self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyz, dist) ++ ] ++ ++ def get_traffic_light_status( ++ self, lane_id: str, scene_ts: int ++ ) -> TrafficLightStatus: ++ return ( ++ self.traffic_light_status.get( ++ (int(lane_id), scene_ts), TrafficLightStatus.NO_DATA ++ ) ++ if self.traffic_light_status is not None ++ else TrafficLightStatus.NO_DATA ++ ) ++ ++ def rasterize( ++ self, resolution: float = 2, **kwargs ++ ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: ++ """Renders this vector map at the specified resolution. ++ ++ Args: ++ resolution (float): The rasterized image's resolution in pixels per meter. ++ ++ Returns: ++ np.ndarray: The rasterized RGB image. ++ """ ++ return_tf_mat: bool = kwargs.get("return_tf_mat", False) ++ incl_centerlines: bool = kwargs.get("incl_centerlines", True) ++ incl_lane_edges: bool = kwargs.get("incl_lane_edges", True) ++ incl_lane_area: bool = kwargs.get("incl_lane_area", True) ++ ++ scene_ts: Optional[int] = kwargs.get("scene_ts", None) ++ ++ # (255, 102, 99) also looks nice. ++ center_color: Tuple[int, int, int] = kwargs.get("center_color", (129, 51, 255)) ++ # (86, 203, 249) also looks nice. ++ edge_color: Tuple[int, int, int] = kwargs.get("edge_color", (118, 185, 0)) ++ # (191, 215, 234) also looks nice. ++ area_color: Tuple[int, int, int] = kwargs.get("area_color", (214, 232, 181)) ++ ++ min_x, min_y, _, max_x, max_y, _ = self.extent ++ ++ world_center_m: Tuple[float, float] = ( ++ (max_x + min_x) / 2, ++ (max_y + min_y) / 2, ++ ) ++ ++ raster_size_x: int = ceil((max_x - min_x) * resolution) ++ raster_size_y: int = ceil((max_y - min_y) * resolution) ++ ++ raster_from_local: np.ndarray = np.array( ++ [ ++ [resolution, 0, raster_size_x / 2], ++ [0, resolution, raster_size_y / 2], ++ [0, 0, 1], ++ ] ++ ) ++ ++ # Compute pose from its position and rotation. ++ pose_from_world: np.ndarray = np.array( ++ [ ++ [1, 0, -world_center_m[0]], ++ [0, 1, -world_center_m[1]], ++ [0, 0, 1], ++ ] ++ ) ++ ++ raster_from_world: np.ndarray = raster_from_local @ pose_from_world ++ ++ map_img: np.ndarray = np.zeros( ++ shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 ++ ) ++ ++ lane_edges: List[np.ndarray] = list() ++ centerlines: List[np.ndarray] = list() ++ lane: RoadLane ++ for lane in tqdm( ++ self.elements[MapElementType.ROAD_LANE].values(), ++ desc=f"Rasterizing Map at {resolution:.2f} px/m", ++ leave=False, ++ ): ++ centerlines.append( ++ raster_utils.world_to_subpixel( ++ lane.center.points[:, :2], raster_from_world ++ ) ++ ) ++ if lane.left_edge is not None and lane.right_edge is not None: ++ left_pts: np.ndarray = lane.left_edge.points[:, :2] ++ right_pts: np.ndarray = lane.right_edge.points[:, :2] ++ ++ lane_edges += [ ++ raster_utils.world_to_subpixel(left_pts, raster_from_world), ++ raster_utils.world_to_subpixel(right_pts, raster_from_world), ++ ] ++ ++ lane_color = area_color ++ status = self.get_traffic_light_status(lane.id, scene_ts) ++ if status == TrafficLightStatus.GREEN: ++ lane_color = [0, 200, 0] ++ elif status == TrafficLightStatus.RED: ++ lane_color = [200, 0, 0] ++ elif status == TrafficLightStatus.UNKNOWN: ++ lane_color = [150, 150, 0] ++ ++ # Drawing lane areas. Need to do per loop because doing it all at once can ++ # create lots of wonky holes in the image. ++ # See https://stackoverflow.com/questions/69768620/cv2-fillpoly-failing-for-intersecting-polygons ++ if incl_lane_area: ++ lane_area: np.ndarray = np.concatenate( ++ [left_pts, right_pts[::-1]], axis=0 ++ ) ++ raster_utils.rasterize_world_polygon( ++ lane_area, ++ map_img, ++ raster_from_world, ++ color=lane_color, ++ ) ++ ++ # Drawing all lane edge lines at the same time. ++ if incl_lane_edges: ++ raster_utils.cv2_draw_polylines(lane_edges, map_img, color=edge_color) ++ ++ # Drawing centerlines last (on top of everything else). ++ if incl_centerlines: ++ raster_utils.cv2_draw_polylines(centerlines, map_img, color=center_color) ++ ++ if return_tf_mat: ++ return map_img.astype(float) / 255, raster_from_world ++ else: ++ return map_img.astype(float) / 255 ++ ++ @overload ++ def visualize_lane_graph( ++ self, ++ origin_lane: RoadLane, ++ num_hops: int, ++ **kwargs, ++ ) -> Axes: ++ ... ++ ++ @overload ++ def visualize_lane_graph(self, origin_lane: str, num_hops: int, **kwargs) -> Axes: ++ ... ++ ++ @overload ++ def visualize_lane_graph(self, origin_lane: int, num_hops: int, **kwargs) -> Axes: ++ ... ++ ++ def visualize_lane_graph( ++ self, origin_lane: Union[RoadLane, str, int], num_hops: int, **kwargs ++ ) -> Axes: ++ ax = kwargs.get("ax", None) ++ if ax is None: ++ fig, ax = plt.subplots() ++ ++ origin: str ++ if isinstance(origin_lane, RoadLane): ++ origin = origin_lane.id ++ elif isinstance(origin_lane, str): ++ origin = origin_lane ++ elif isinstance(origin_lane, int): ++ origin = self.lanes[origin_lane].id ++ ++ viridis = mpl.colormaps[kwargs.get("cmap", "rainbow")].resampled(num_hops + 1) ++ ++ already_seen: Set[str] = set() ++ lanes_to_plot: List[Tuple[str, int]] = [(origin, 0)] ++ ++ if kwargs.get("legend", True): ++ ax.scatter([], [], label=f"Lane Endpoints", color="k") ++ ax.plot([], [], label=f"Origin Lane ({origin})", color=viridis(0)) ++ for h in range(1, num_hops + 1): ++ ax.plot( ++ [], ++ [], ++ label=f"{h} Lane{'s' if h > 1 else ''} Away", ++ color=viridis(h), ++ ) ++ ++ raster_from_world = kwargs.get("raster_from_world", None) ++ while len(lanes_to_plot) > 0: ++ lane_id, curr_hops = lanes_to_plot.pop(0) ++ already_seen.add(lane_id) ++ lane: RoadLane = self.get_road_lane(lane_id) ++ ++ center: np.ndarray = lane.center.points[..., :2] ++ first_pt_heading: float = lane.center.points[0, -1] ++ mdpt: np.ndarray = lane.center.midpoint[..., :2] ++ ++ if raster_from_world is not None: ++ center = map_utils.transform_points(center, raster_from_world) ++ mdpt = map_utils.transform_points(mdpt[None, :], raster_from_world)[0] ++ ++ ax.plot(center[:, 0], center[:, 1], color=viridis(curr_hops)) ++ ax.scatter(center[[0, -1], 0], center[[0, -1], 1], color=viridis(curr_hops)) ++ ax.quiver( ++ [center[0, 0]], ++ [center[0, 1]], ++ [np.cos(first_pt_heading)], ++ [np.sin(first_pt_heading)], ++ color=viridis(curr_hops), ++ ) ++ ax.text(mdpt[0], mdpt[1], s=lane_id) ++ ++ if curr_hops < num_hops: ++ lanes_to_plot += [ ++ (l, curr_hops + 1) ++ for l in lane.reachable_lanes ++ if l not in already_seen ++ ] ++ ++ if kwargs.get("legend", True): ++ ax.legend(loc="best", frameon=True) ++ ++ return ax +\ No newline at end of file +diff --git a/src/trajdata/proto/vectorized_map.proto b/src/trajdata/proto/vectorized_map.proto +index e1cd502..8f68281 100644 +--- a/src/trajdata/proto/vectorized_map.proto ++++ b/src/trajdata/proto/vectorized_map.proto +@@ -39,14 +39,14 @@ message Point { + } + + message Polyline { +- // Position deltas in millimeters. The origin is an arbitrary location. +- // From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 ++ // Position deltas in 10^-5 meters. The origin is an arbitrary location. ++ // Inspired by https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 + // The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + // the origin. For subsequent points, this field stores the difference between the point's + // coordinates and the previous point's coordinates. This is for representation efficiency. +- repeated sint32 dx_mm = 1; +- repeated sint32 dy_mm = 2; +- repeated sint32 dz_mm = 3; ++ repeated sint64 dx_mm = 1; ++ repeated sint64 dy_mm = 2; ++ repeated sint64 dz_mm = 3; + repeated double h_rad = 4; + } + +@@ -74,6 +74,9 @@ message RoadLane { + // A list of neighbors to the right of this lane. Neighbor lanes + // include only adjacent lanes going the same direction. + repeated bytes adjacent_lanes_right = 7; ++ ++ // A list of associated road area ids. ++ repeated bytes road_area_ids = 8; + } + + message RoadArea { +diff --git a/src/trajdata/proto/vectorized_map_pb2.py b/src/trajdata/proto/vectorized_map_pb2.py +index 7352109..c5910d9 100644 +--- a/src/trajdata/proto/vectorized_map_pb2.py ++++ b/src/trajdata/proto/vectorized_map_pb2.py +@@ -1,136 +1,545 @@ + # -*- coding: utf-8 -*- + # Generated by the protocol buffer compiler. DO NOT EDIT! + # source: vectorized_map.proto +-"""Generated protocol buffer code.""" ++ + from google.protobuf import descriptor as _descriptor +-from google.protobuf import descriptor_pool as _descriptor_pool + from google.protobuf import message as _message + from google.protobuf import reflection as _reflection + from google.protobuf import symbol_database as _symbol_database +- + # @@protoc_insertion_point(imports) + + _sym_db = _symbol_database.Default() + + +-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( +- b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' ++ ++ ++DESCRIPTOR = _descriptor.FileDescriptor( ++ name='vectorized_map.proto', ++ package='trajdata', ++ syntax='proto3', ++ serialized_options=None, ++ create_key=_descriptor._internal_create_key, ++ serialized_pb=b'\n\x14vectorized_map.proto\x12\x08trajdata\"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point\"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data\"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x12\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x12\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x12\x12\r\n\x05h_rad\x18\x04 \x03(\x01\"\xaf\x02\n\x08RoadLane\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x12\x15\n\rroad_area_ids\x18\x08 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary\"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline\"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' + ) + + +-_VECTORIZEDMAP = DESCRIPTOR.message_types_by_name["VectorizedMap"] +-_MAPELEMENT = DESCRIPTOR.message_types_by_name["MapElement"] +-_POINT = DESCRIPTOR.message_types_by_name["Point"] +-_POLYLINE = DESCRIPTOR.message_types_by_name["Polyline"] +-_ROADLANE = DESCRIPTOR.message_types_by_name["RoadLane"] +-_ROADAREA = DESCRIPTOR.message_types_by_name["RoadArea"] +-_PEDCROSSWALK = DESCRIPTOR.message_types_by_name["PedCrosswalk"] +-_PEDWALKWAY = DESCRIPTOR.message_types_by_name["PedWalkway"] +-VectorizedMap = _reflection.GeneratedProtocolMessageType( +- "VectorizedMap", +- (_message.Message,), +- { +- "DESCRIPTOR": _VECTORIZEDMAP, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) +- }, ++ ++ ++_VECTORIZEDMAP = _descriptor.Descriptor( ++ name='VectorizedMap', ++ full_name='trajdata.VectorizedMap', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='name', full_name='trajdata.VectorizedMap.name', index=0, ++ number=1, type=9, cpp_type=9, label=1, ++ has_default_value=False, default_value=b"".decode('utf-8'), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='elements', full_name='trajdata.VectorizedMap.elements', index=1, ++ number=2, type=11, cpp_type=10, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='max_pt', full_name='trajdata.VectorizedMap.max_pt', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='min_pt', full_name='trajdata.VectorizedMap.min_pt', index=3, ++ number=4, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='shifted_origin', full_name='trajdata.VectorizedMap.shifted_origin', index=4, ++ number=5, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=35, ++ serialized_end=211, + ) +-_sym_db.RegisterMessage(VectorizedMap) + +-MapElement = _reflection.GeneratedProtocolMessageType( +- "MapElement", +- (_message.Message,), +- { +- "DESCRIPTOR": _MAPELEMENT, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.MapElement) +- }, ++ ++_MAPELEMENT = _descriptor.Descriptor( ++ name='MapElement', ++ full_name='trajdata.MapElement', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='id', full_name='trajdata.MapElement.id', index=0, ++ number=1, type=12, cpp_type=9, label=1, ++ has_default_value=False, default_value=b"", ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_lane', full_name='trajdata.MapElement.road_lane', index=1, ++ number=2, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_area', full_name='trajdata.MapElement.road_area', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='ped_crosswalk', full_name='trajdata.MapElement.ped_crosswalk', index=3, ++ number=4, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='ped_walkway', full_name='trajdata.MapElement.ped_walkway', index=4, ++ number=5, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ _descriptor.OneofDescriptor( ++ name='element_data', full_name='trajdata.MapElement.element_data', ++ index=0, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ ], ++ serialized_start=214, ++ serialized_end=430, + ) +-_sym_db.RegisterMessage(MapElement) + +-Point = _reflection.GeneratedProtocolMessageType( +- "Point", +- (_message.Message,), +- { +- "DESCRIPTOR": _POINT, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.Point) +- }, ++ ++_POINT = _descriptor.Descriptor( ++ name='Point', ++ full_name='trajdata.Point', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='x', full_name='trajdata.Point.x', index=0, ++ number=1, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='y', full_name='trajdata.Point.y', index=1, ++ number=2, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='z', full_name='trajdata.Point.z', index=2, ++ number=3, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=432, ++ serialized_end=472, + ) +-_sym_db.RegisterMessage(Point) + +-Polyline = _reflection.GeneratedProtocolMessageType( +- "Polyline", +- (_message.Message,), +- { +- "DESCRIPTOR": _POLYLINE, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.Polyline) +- }, ++ ++_POLYLINE = _descriptor.Descriptor( ++ name='Polyline', ++ full_name='trajdata.Polyline', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='dx_mm', full_name='trajdata.Polyline.dx_mm', index=0, ++ number=1, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='dy_mm', full_name='trajdata.Polyline.dy_mm', index=1, ++ number=2, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='dz_mm', full_name='trajdata.Polyline.dz_mm', index=2, ++ number=3, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='h_rad', full_name='trajdata.Polyline.h_rad', index=3, ++ number=4, type=1, cpp_type=5, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=474, ++ serialized_end=544, + ) +-_sym_db.RegisterMessage(Polyline) + +-RoadLane = _reflection.GeneratedProtocolMessageType( +- "RoadLane", +- (_message.Message,), +- { +- "DESCRIPTOR": _ROADLANE, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.RoadLane) +- }, ++ ++_ROADLANE = _descriptor.Descriptor( ++ name='RoadLane', ++ full_name='trajdata.RoadLane', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='center', full_name='trajdata.RoadLane.center', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='left_boundary', full_name='trajdata.RoadLane.left_boundary', index=1, ++ number=2, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='right_boundary', full_name='trajdata.RoadLane.right_boundary', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='entry_lanes', full_name='trajdata.RoadLane.entry_lanes', index=3, ++ number=4, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='exit_lanes', full_name='trajdata.RoadLane.exit_lanes', index=4, ++ number=5, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='adjacent_lanes_left', full_name='trajdata.RoadLane.adjacent_lanes_left', index=5, ++ number=6, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='adjacent_lanes_right', full_name='trajdata.RoadLane.adjacent_lanes_right', index=6, ++ number=7, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_area_ids', full_name='trajdata.RoadLane.road_area_ids', index=7, ++ number=8, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ _descriptor.OneofDescriptor( ++ name='_left_boundary', full_name='trajdata.RoadLane._left_boundary', ++ index=0, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ _descriptor.OneofDescriptor( ++ name='_right_boundary', full_name='trajdata.RoadLane._right_boundary', ++ index=1, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ ], ++ serialized_start=547, ++ serialized_end=850, + ) +-_sym_db.RegisterMessage(RoadLane) + +-RoadArea = _reflection.GeneratedProtocolMessageType( +- "RoadArea", +- (_message.Message,), +- { +- "DESCRIPTOR": _ROADAREA, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.RoadArea) +- }, ++ ++_ROADAREA = _descriptor.Descriptor( ++ name='RoadArea', ++ full_name='trajdata.RoadArea', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='exterior_polygon', full_name='trajdata.RoadArea.exterior_polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='interior_holes', full_name='trajdata.RoadArea.interior_holes', index=1, ++ number=2, type=11, cpp_type=10, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=852, ++ serialized_end=952, + ) +-_sym_db.RegisterMessage(RoadArea) + +-PedCrosswalk = _reflection.GeneratedProtocolMessageType( +- "PedCrosswalk", +- (_message.Message,), +- { +- "DESCRIPTOR": _PEDCROSSWALK, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) +- }, ++ ++_PEDCROSSWALK = _descriptor.Descriptor( ++ name='PedCrosswalk', ++ full_name='trajdata.PedCrosswalk', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='polygon', full_name='trajdata.PedCrosswalk.polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=954, ++ serialized_end=1005, + ) +-_sym_db.RegisterMessage(PedCrosswalk) + +-PedWalkway = _reflection.GeneratedProtocolMessageType( +- "PedWalkway", +- (_message.Message,), +- { +- "DESCRIPTOR": _PEDWALKWAY, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) +- }, ++ ++_PEDWALKWAY = _descriptor.Descriptor( ++ name='PedWalkway', ++ full_name='trajdata.PedWalkway', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='polygon', full_name='trajdata.PedWalkway.polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=1007, ++ serialized_end=1056, + ) ++ ++_VECTORIZEDMAP.fields_by_name['elements'].message_type = _MAPELEMENT ++_VECTORIZEDMAP.fields_by_name['max_pt'].message_type = _POINT ++_VECTORIZEDMAP.fields_by_name['min_pt'].message_type = _POINT ++_VECTORIZEDMAP.fields_by_name['shifted_origin'].message_type = _POINT ++_MAPELEMENT.fields_by_name['road_lane'].message_type = _ROADLANE ++_MAPELEMENT.fields_by_name['road_area'].message_type = _ROADAREA ++_MAPELEMENT.fields_by_name['ped_crosswalk'].message_type = _PEDCROSSWALK ++_MAPELEMENT.fields_by_name['ped_walkway'].message_type = _PEDWALKWAY ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['road_lane']) ++_MAPELEMENT.fields_by_name['road_lane'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['road_area']) ++_MAPELEMENT.fields_by_name['road_area'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['ped_crosswalk']) ++_MAPELEMENT.fields_by_name['ped_crosswalk'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['ped_walkway']) ++_MAPELEMENT.fields_by_name['ped_walkway'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_ROADLANE.fields_by_name['center'].message_type = _POLYLINE ++_ROADLANE.fields_by_name['left_boundary'].message_type = _POLYLINE ++_ROADLANE.fields_by_name['right_boundary'].message_type = _POLYLINE ++_ROADLANE.oneofs_by_name['_left_boundary'].fields.append( ++ _ROADLANE.fields_by_name['left_boundary']) ++_ROADLANE.fields_by_name['left_boundary'].containing_oneof = _ROADLANE.oneofs_by_name['_left_boundary'] ++_ROADLANE.oneofs_by_name['_right_boundary'].fields.append( ++ _ROADLANE.fields_by_name['right_boundary']) ++_ROADLANE.fields_by_name['right_boundary'].containing_oneof = _ROADLANE.oneofs_by_name['_right_boundary'] ++_ROADAREA.fields_by_name['exterior_polygon'].message_type = _POLYLINE ++_ROADAREA.fields_by_name['interior_holes'].message_type = _POLYLINE ++_PEDCROSSWALK.fields_by_name['polygon'].message_type = _POLYLINE ++_PEDWALKWAY.fields_by_name['polygon'].message_type = _POLYLINE ++DESCRIPTOR.message_types_by_name['VectorizedMap'] = _VECTORIZEDMAP ++DESCRIPTOR.message_types_by_name['MapElement'] = _MAPELEMENT ++DESCRIPTOR.message_types_by_name['Point'] = _POINT ++DESCRIPTOR.message_types_by_name['Polyline'] = _POLYLINE ++DESCRIPTOR.message_types_by_name['RoadLane'] = _ROADLANE ++DESCRIPTOR.message_types_by_name['RoadArea'] = _ROADAREA ++DESCRIPTOR.message_types_by_name['PedCrosswalk'] = _PEDCROSSWALK ++DESCRIPTOR.message_types_by_name['PedWalkway'] = _PEDWALKWAY ++_sym_db.RegisterFileDescriptor(DESCRIPTOR) ++ ++VectorizedMap = _reflection.GeneratedProtocolMessageType('VectorizedMap', (_message.Message,), { ++ 'DESCRIPTOR' : _VECTORIZEDMAP, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) ++ }) ++_sym_db.RegisterMessage(VectorizedMap) ++ ++MapElement = _reflection.GeneratedProtocolMessageType('MapElement', (_message.Message,), { ++ 'DESCRIPTOR' : _MAPELEMENT, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.MapElement) ++ }) ++_sym_db.RegisterMessage(MapElement) ++ ++Point = _reflection.GeneratedProtocolMessageType('Point', (_message.Message,), { ++ 'DESCRIPTOR' : _POINT, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.Point) ++ }) ++_sym_db.RegisterMessage(Point) ++ ++Polyline = _reflection.GeneratedProtocolMessageType('Polyline', (_message.Message,), { ++ 'DESCRIPTOR' : _POLYLINE, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.Polyline) ++ }) ++_sym_db.RegisterMessage(Polyline) ++ ++RoadLane = _reflection.GeneratedProtocolMessageType('RoadLane', (_message.Message,), { ++ 'DESCRIPTOR' : _ROADLANE, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.RoadLane) ++ }) ++_sym_db.RegisterMessage(RoadLane) ++ ++RoadArea = _reflection.GeneratedProtocolMessageType('RoadArea', (_message.Message,), { ++ 'DESCRIPTOR' : _ROADAREA, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.RoadArea) ++ }) ++_sym_db.RegisterMessage(RoadArea) ++ ++PedCrosswalk = _reflection.GeneratedProtocolMessageType('PedCrosswalk', (_message.Message,), { ++ 'DESCRIPTOR' : _PEDCROSSWALK, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) ++ }) ++_sym_db.RegisterMessage(PedCrosswalk) ++ ++PedWalkway = _reflection.GeneratedProtocolMessageType('PedWalkway', (_message.Message,), { ++ 'DESCRIPTOR' : _PEDWALKWAY, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) ++ }) + _sym_db.RegisterMessage(PedWalkway) + +-if _descriptor._USE_C_DESCRIPTORS == False: +- +- DESCRIPTOR._options = None +- _VECTORIZEDMAP._serialized_start = 35 +- _VECTORIZEDMAP._serialized_end = 211 +- _MAPELEMENT._serialized_start = 214 +- _MAPELEMENT._serialized_end = 430 +- _POINT._serialized_start = 432 +- _POINT._serialized_end = 472 +- _POLYLINE._serialized_start = 474 +- _POLYLINE._serialized_end = 544 +- _ROADLANE._serialized_start = 547 +- _ROADLANE._serialized_end = 827 +- _ROADAREA._serialized_start = 829 +- _ROADAREA._serialized_end = 929 +- _PEDCROSSWALK._serialized_start = 931 +- _PEDCROSSWALK._serialized_end = 982 +- _PEDWALKWAY._serialized_start = 984 +- _PEDWALKWAY._serialized_end = 1033 ++ + # @@protoc_insertion_point(module_scope) +diff --git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py +index e76a678..6f10bb4 100644 +--- a/src/trajdata/utils/arr_utils.py ++++ b/src/trajdata/utils/arr_utils.py +@@ -93,7 +93,9 @@ def vrange(starts: np.ndarray, stops: np.ndarray) -> np.ndarray: + return np.repeat(stops - lens.cumsum(), lens) + np.arange(lens.sum()) + + +-def angle_wrap(radians: np.ndarray) -> np.ndarray: ++def angle_wrap( ++ radians: Union[np.ndarray, torch.Tensor] ++) -> Union[np.ndarray, torch.Tensor]: + """This function wraps angles to lie within [-pi, pi). + + Args: +@@ -130,12 +132,12 @@ def rotation_matrix(angle: Union[float, np.ndarray]) -> np.ndarray: + return rotmat.transpose(*np.arange(2, batch_dims + 2), 0, 1) + + +-def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: ++def transform_matrices(angles: Tensor, translations: Optional[Tensor]) -> Tensor: + """Creates a 3x3 transformation matrix for each angle and translation in the input. + + Args: +- angles (Tensor): The (N,)-shaped angles tensor to rotate points by (in radians). +- translations (Tensor): The (N,2)-shaped translations to shift points by. ++ angles (Tensor): The (...)-shaped angles tensor to rotate points by (in radians). ++ translations (Tensor): The (...,2)-shaped translations to shift points by. + + Returns: + Tensor: The Nx3x3 transformation matrices. +@@ -143,12 +145,19 @@ def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: + cos_vals = torch.cos(angles) + sin_vals = torch.sin(angles) + last_rows = torch.tensor( +- [[0.0, 0.0, 1.0]], dtype=angles.dtype, device=angles.device +- ).expand((angles.shape[0], -1)) ++ [0.0, 0.0, 1.0], dtype=angles.dtype, device=angles.device ++ ).view([1] * angles.ndim + [3]).expand(list(angles.shape) + [-1]) ++ ++ if translations is None: ++ trans_x = torch.zeros_like(angles) ++ trans_y = trans_x ++ else: ++ trans_x, trans_y = torch.unbind(translations, dim=-1) ++ + return torch.stack( + [ +- torch.stack([cos_vals, -sin_vals, translations[:, 0]], dim=-1), +- torch.stack([sin_vals, cos_vals, translations[:, 1]], dim=-1), ++ torch.stack([cos_vals, -sin_vals, trans_x], dim=-1), ++ torch.stack([sin_vals, cos_vals, trans_y], dim=-1), + last_rows, + ], + dim=-2, +@@ -243,6 +252,125 @@ def transform_xyh_np(xyh: np.ndarray, tf_mat: np.ndarray) -> np.ndarray: + transformed_angles = transform_angles_np(xyh[..., 2], tf_mat) + return np.concatenate([transformed_xy, transformed_angles[..., None]], axis=-1) + ++def transform_xyh_torch(xyh: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor: ++ """ ++ Returns transformed set of xyh points ++ ++ Args: ++ xyh (torch.Tensor): shape [...,3] ++ tf_mat (torch.Tensor): shape [...,3,3] ++ """ ++ transformed_xy = batch_nd_transform_points_pt(xyh[..., :2], tf_mat) ++ transformed_angles = batch_nd_transform_angles_pt(xyh[..., 2], tf_mat) ++ return torch.cat([transformed_xy, transformed_angles[..., None]], dim=-1) ++ ++# -------- TODO(pkarkus) redundant transforms, remove them ++ ++ ++def batch_nd_transform_points_np(points: np.ndarray, Mat: np.ndarray) -> np.ndarray: ++ ndim = Mat.shape[-1] - 1 ++ batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] ++ Mat = np.transpose(Mat, batch) ++ if points.ndim == Mat.ndim - 1: ++ return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ ++ ..., -1:, :ndim ++ ].squeeze(-2) ++ elif points.ndim == Mat.ndim: ++ return ( ++ (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) ++ + Mat[..., np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ else: ++ raise Exception("wrong shape") ++ ++def batch_nd_transform_points_pt( ++ points: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ ndim = Mat.shape[-1] - 1 ++ Mat = torch.transpose(Mat, -1, -2) ++ if points.ndim == Mat.ndim - 1: ++ return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ ++ ..., -1:, :ndim ++ ].squeeze(-2) ++ elif points.ndim == Mat.ndim: ++ return ( ++ (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) ++ + Mat[..., np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ elif points.ndim == Mat.ndim + 1: ++ return ( ++ ( ++ points[..., np.newaxis, :] ++ @ Mat[..., np.newaxis, np.newaxis, :ndim, :ndim] ++ ) ++ + Mat[..., np.newaxis, np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ else: ++ raise Exception("wrong shape") ++ ++ ++def batch_nd_transform_angles_np(angles: np.ndarray, Mat: np.ndarray) -> np.ndarray: ++ cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] ++ rot_angle = np.arctan2(sin_vals, cos_vals) ++ angles = angles + rot_angle ++ angles = angle_wrap(angles) ++ return angles ++ ++ ++def batch_nd_transform_angles_pt( ++ angles: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] ++ rot_angle = torch.arctan2(sin_vals, cos_vals) ++ if rot_angle.ndim > angles.ndim: ++ raise ValueError("wrong shape") ++ while rot_angle.ndim < angles.ndim: ++ rot_angle = rot_angle.unsqueeze(-1) ++ angles = angles + rot_angle ++ angles = angle_wrap(angles) ++ return angles ++ ++ ++def batch_nd_transform_points_angles_np( ++ points_angles: np.ndarray, Mat: np.ndarray ++) -> np.ndarray: ++ assert points_angles.shape[-1] == 3 ++ points = batch_nd_transform_points_np(points_angles[..., :2], Mat) ++ angles = batch_nd_transform_angles_np(points_angles[..., 2:3], Mat) ++ points_angles = np.concatenate([points, angles], axis=-1) ++ return points_angles ++ ++ ++def batch_nd_transform_points_angles_pt( ++ points_angles: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ assert points_angles.shape[-1] == 3 ++ points = batch_nd_transform_points_pt(points_angles[..., :2], Mat) ++ angles = batch_nd_transform_angles_pt(points_angles[..., 2:3], Mat) ++ points_angles = torch.concat([points, angles], axis=-1) ++ return points_angles ++ ++ ++def batch_nd_transform_xyvvaahh_pt(traj_xyvvaahh: torch.Tensor, tf: torch.Tensor) -> torch.Tensor: ++ """ ++ traj_xyvvaahh: [..., state_dim] where state_dim = [x, y, vx, vy, ax, ay, sinh, cosh] ++ This is the state representation used in AgentBatch and SceneBatch. ++ """ ++ rot_only_tf = tf.clone() ++ rot_only_tf[..., :2, -1] = 0. ++ ++ xy, vv, aa, hh = torch.split(traj_xyvvaahh, (2, 2, 2, 2), dim=-1) ++ xy = batch_nd_transform_points_pt(xy, tf) ++ vv = batch_nd_transform_points_pt(vv, rot_only_tf) ++ aa = batch_nd_transform_points_pt(aa, rot_only_tf) ++ # hh: sinh, cosh instead of cosh, sinh, so we use flip ++ hh = batch_nd_transform_points_pt(hh.flip(-1), rot_only_tf).flip(-1) ++ ++ return torch.concat((xy, vv, aa, hh), dim=-1) ++ ++ ++# -------- end of redundant transforms ++ + + def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: + values_diff: np.ndarray = np.diff( +@@ -325,3 +453,120 @@ def quaternion_to_yaw(q: np.ndarray): + 2 * (q[..., 0] * q[..., 3] - q[..., 1] * q[..., 2]), + 1 - 2 * (q[..., 2] ** 2 + q[..., 3] ** 2), + ) ++ ++ ++def batch_select( ++ x: torch.Tensor, ++ index: torch.Tensor, ++ batch_dims: int ++) -> torch.Tensor: ++ # Indexing into tensor, treating the first `batch_dims` dimensions as batch. ++ # Kind of: output[..., k] = x[..., index[...]] ++ ++ assert index.ndim >= batch_dims ++ assert index.ndim <= x.ndim ++ assert x.shape[:batch_dims] == index.shape[:batch_dims] ++ ++ batch_shape = x.shape[:batch_dims] ++ x_flat = x.reshape(-1, *x.shape[batch_dims:]) ++ index_flat = index.reshape(-1, *index.shape[batch_dims:]) ++ x_flat = x_flat[torch.arange(x_flat.shape[0]), index_flat] ++ x = x_flat.reshape(*batch_shape, *x_flat.shape[1:]) ++ ++ return x ++ ++ ++def roll_with_tensor(mat: torch.Tensor, shifts: torch.LongTensor, dim: int): ++ if dim < 0: ++ dim = mat.ndim + dim ++ arange1 = torch.arange(mat.shape[dim], device=shifts.device) ++ expanded_shape = [1] * dim + [-1] + [1] * (mat.ndim-dim-1) ++ arange1 = arange1.view(expanded_shape).expand(mat.shape) ++ if shifts.ndim == 1: ++ shifts = shifts.view([1] * (dim-1) + [-1]) ++ # TODO(pkarkus) assert that shift dimenesions either match mat or 1 ++ shifts = shifts.view(list(shifts.shape) + [1] * (mat.ndim-dim)) ++ ++ arange2 = (arange1 - shifts) % mat.shape[dim] ++ # print(arange2) ++ return torch.gather(mat, dim, arange2) ++ ++def round_2pi(x): ++ return (x + np.pi) % (2 * np.pi) - np.pi ++ ++def batch_proj(x, line): ++ # x:[batch,3], line:[batch,N,3] ++ line_length = line.shape[-2] ++ batch_dim = x.ndim - 1 ++ if isinstance(x, torch.Tensor): ++ delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( ++ *([1] * batch_dim), line_length, 1 ++ ) ++ dis = torch.linalg.norm(delta, axis=-1) ++ idx0 = torch.argmin(dis, dim=-1) ++ idx = idx0.view(*line.shape[:-2], 1, 1).repeat( ++ *([1] * (batch_dim + 1)), line.shape[-1] ++ ) ++ line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) ++ dx = x[..., None, 0] - line[..., 0] ++ dy = x[..., None, 1] - line[..., 1] ++ delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( ++ line_min[..., None, 2] ++ ) ++ delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( ++ line_min[..., None, 2] ++ ) ++ ++ delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) ++ ++ return ( ++ delta_x, ++ delta_y, ++ torch.unsqueeze(delta_psi, dim=-1), ++ ) ++ elif isinstance(x, np.ndarray): ++ delta = line[..., 0:2] - np.repeat( ++ x[..., np.newaxis, 0:2], line_length, axis=-2 ++ ) ++ dis = np.linalg.norm(delta, axis=-1) ++ idx0 = np.argmin(dis, axis=-1) ++ idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) ++ line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) ++ dx = x[..., None, 0] - line[..., 0] ++ dy = x[..., None, 1] - line[..., 1] ++ delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( ++ line_min[..., None, 2] ++ ) ++ delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( ++ line_min[..., None, 2] ++ ) ++ ++ delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) ++ return ( ++ delta_x, ++ delta_y, ++ np.expand_dims(delta_psi, axis=-1), ++ ) ++ ++def get_close_lanes(radius,ego_xyh,vec_map,num_pts): ++ # obtain close lanes, their distance to the ego ++ close_lanes = [] ++ while len(close_lanes)==0: ++ close_lanes=vec_map.get_lanes_within(ego_xyh,radius) ++ radius+=20 ++ dis = list() ++ lane_pts = np.stack([lane.center.interpolate(num_pts).points[:,[0,1,3]] for lane in close_lanes],0) ++ dx,dy,dh = batch_proj(ego_xyh[None].repeat(lane_pts.shape[0],0),lane_pts) ++ ++ idx = np.abs(dx).argmin(axis=1) ++ # hausdorff distance to the lane (longitudinal) ++ x_dis = np.take_along_axis(np.abs(dx),idx[:,None],axis=1).squeeze(1) ++ x_dis[(dx.min(1)<0) & (dx.max(1)>0)] = 0 ++ ++ y_dis = np.take_along_axis(np.abs(dy),idx[:,None],axis=1).squeeze(1) ++ ++ # distance metric to the lane (combining x,y) ++ dis = x_dis+y_dis ++ ++ ++ return close_lanes,dis +diff --git a/src/trajdata/utils/batch_utils.py b/src/trajdata/utils/batch_utils.py +index cf63009..555e790 100644 +--- a/src/trajdata/utils/batch_utils.py ++++ b/src/trajdata/utils/batch_utils.py +@@ -1,5 +1,9 @@ + from collections import defaultdict +-from typing import Any, Dict, Iterator, List, Optional, Tuple ++ ++import torch ++ ++from pathlib import Path ++from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + + import numpy as np + from torch.utils.data import Sampler +@@ -10,10 +14,15 @@ from trajdata.data_structures import ( + AgentBatchElement, + AgentDataIndex, + AgentType, ++ SceneBatch, + SceneBatchElement, + SceneTimeAgent, + ) +-from trajdata.data_structures.collation import agent_collate_fn ++from trajdata.data_structures.collation import agent_collate_fn, batch_rotate_raster_maps_for_agents_in_scene ++from trajdata.maps import RasterizedMapPatch ++from trajdata.utils.map_utils import load_map_patch ++from trajdata.utils.arr_utils import batch_nd_transform_xyvvaahh_pt, batch_select, PadDirection ++from trajdata.caching.df_cache import DataFrameCache + + + class SceneTimeBatcher(Sampler): +@@ -173,3 +182,107 @@ def convert_to_agent_batch( + ) + + return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format) ++ ++ ++def get_agents_map_patch( ++ cache_path: Path, ++ map_name: str, ++ patch_params: Dict[str, int], ++ agent_world_states_xyh: Union[np.ndarray, torch.Tensor], ++ allow_nan: float = False, ++) -> List[RasterizedMapPatch]: ++ ++ if isinstance(agent_world_states_xyh, torch.Tensor): ++ agent_world_states_xyh = agent_world_states_xyh.cpu().numpy() ++ assert agent_world_states_xyh.ndim == 2 ++ assert agent_world_states_xyh.shape[-1] == 3 ++ ++ desired_patch_size: int = patch_params["map_size_px"] ++ resolution: float = patch_params["px_per_m"] ++ offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) ++ return_rgb: bool = patch_params.get("return_rgb", True) ++ no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) ++ ++ env_name, location_name = map_name.split(':') # assumes map_name format nusc_mini:boston-seaport ++ ++ map_patches = list() ++ ++ ( ++ maps_path, ++ _, ++ _, ++ raster_map_path, ++ raster_metadata_path, ++ ) = DataFrameCache.get_map_paths( ++ cache_path, env_name, location_name, resolution ++ ) ++ ++ for i in range(agent_world_states_xyh.shape[0]): ++ patch_data, raster_from_world_tf, has_data = load_map_patch( ++ raster_map_path, ++ raster_metadata_path, ++ agent_world_states_xyh[i, 0], ++ agent_world_states_xyh[i, 1], ++ desired_patch_size, ++ resolution, ++ offset_xy, ++ agent_world_states_xyh[i, 2], ++ return_rgb, ++ rot_pad_factor=np.sqrt(2), ++ no_map_val=no_map_fill_val, ++ ) ++ map_patches.append( ++ RasterizedMapPatch( ++ data=patch_data, ++ rot_angle=agent_world_states_xyh[i, 2], ++ crop_size=desired_patch_size, ++ resolution=resolution, ++ raster_from_world_tf=raster_from_world_tf, ++ has_data=has_data, ++ ) ++ ) ++ ++ return map_patches ++ ++ ++def get_raster_maps_for_scene_batch(batch: SceneBatch, cache_path: Path, raster_map_params: Dict): ++ ++ # Get current states ++ agent_states = batch.agent_hist.as_format('x,y,xd,yd,xdd,ydd,s,c') ++ if batch.history_pad_dir == PadDirection.AFTER: ++ agent_states = batch_select(agent_states, index=batch.agent_hist_len-1, batch_dims=2) # b, N, t, 8 ++ else: ++ agent_states = agent_states[:, :, -1] ++ ++ agent_world_states_xyvvaahh = batch_nd_transform_xyvvaahh_pt( ++ agent_states.type_as(batch.centered_world_from_agent_tf), ++ batch.centered_world_from_agent_tf ++ ) ++ ++ agent_world_states_xyh = torch.concat(( ++ agent_world_states_xyvvaahh[..., :2], ++ torch.atan2(agent_world_states_xyvvaahh[..., 6:7], agent_world_states_xyvvaahh[..., 7:8])), dim=-1) ++ ++ maps: List[torch.Tensor] = [] ++ maps_resolution: List[torch.Tensor] = [] ++ raster_from_world_tf: List[torch.Tensor] = [] ++ ++ # Collect map patches for all elements and agents into a flat list ++ num_agents: List[int] = [] ++ map_patches: List[RasterizedMapPatch] = [] ++ ++ for b_i in range(agent_world_states_xyh.shape[0]): ++ num_agents.append(batch.num_agents[b_i]) ++ map_patches += get_agents_map_patch( ++ cache_path, batch.map_names[b_i], raster_map_params, agent_world_states_xyh[b_i, :batch.num_agents[b_i]]) ++ ++ # Batch transform map patches and pad ++ ( ++ maps, ++ maps_resolution, ++ raster_from_world_tf ++ ) = batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches, num_agents, agent_world_states_xyh.shape[1], pad_value=np.nan, ++ ) ++ ++ return maps, maps_resolution, raster_from_world_tf +diff --git a/src/trajdata/utils/comm_utils.py b/src/trajdata/utils/comm_utils.py +new file mode 100644 +index 0000000..594ccb0 +--- /dev/null ++++ b/src/trajdata/utils/comm_utils.py +@@ -0,0 +1,24 @@ ++import numpy as np ++import socket ++ ++from contextlib import closing ++from typing import Callable, Optional ++ ++ ++def find_open_port(): ++ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: ++ s.bind(("", 0)) ++ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) ++ return s.getsockname()[1] ++ ++ ++def find_open_port_in_range(start_port, end_port): ++ for port in range(start_port, end_port+1): ++ try: ++ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: ++ s.bind(('localhost', port)) ++ s.listen(1) ++ return port ++ except OSError: ++ continue ++ return None +diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py +index 4726537..a429f68 100644 +--- a/src/trajdata/utils/env_utils.py ++++ b/src/trajdata/utils/env_utils.py +@@ -33,6 +33,7 @@ except ModuleNotFoundError: + # This can happen if the user did not install trajdata + # with the "trajdata[nuplan]" option. + pass ++from trajdata.dataset_specific.drivesim import DrivesimDataset + + try: + from trajdata.dataset_specific.waymo import WaymoDataset +@@ -60,7 +61,10 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: + ) + + if "nuplan" in dataset_name: +- return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) ++ return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) ++ ++ if "drivesim" in dataset_name: ++ return DrivesimDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + + if "waymo" in dataset_name: + return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) +diff --git a/src/trajdata/utils/map_utils.py b/src/trajdata/utils/map_utils.py +index 01844b7..45c08fd 100644 +--- a/src/trajdata/utils/map_utils.py ++++ b/src/trajdata/utils/map_utils.py +@@ -3,19 +3,300 @@ from __future__ import annotations + from typing import TYPE_CHECKING + + if TYPE_CHECKING: +- from trajdata.maps import map_kdtree, vec_map ++ from trajdata.maps import map_kdtree, vec_map, RasterizedMapMetadata, RasterizedMapPatch + + from pathlib import Path +-from typing import Dict, Final, Optional ++from typing import Dict, Final, Optional, Tuple, List, Union + + import dill ++import kornia + import numpy as np ++import math ++import torch ++import zarr + from scipy.stats import circmean + + import trajdata.proto.vectorized_map_pb2 as map_proto + from trajdata.utils import arr_utils + +-MM_PER_M: Final[float] = 1000 ++NUM_DECIMALS: Final[int] = 5 ++COMPRESSION_SCALE: Final[float] = 10**NUM_DECIMALS ++import enum ++ ++class LaneSegRelation(enum.IntEnum): ++ """ ++ Categorical token describing the relationship between an agent and a Lane ++ """ ++ NOTCONNECTED = 0 ++ NEXT = 1 ++ PREV = 2 ++ LEFT = 3 ++ RIGHT = 4 ++ ++def pad_map_patch( ++ patch: np.ndarray, ++ # top, bot, left, right ++ patch_sides: Tuple[int, int, int, int], ++ patch_size: int, ++ map_dims: Tuple[int, int, int], ++) -> np.ndarray: ++ # TODO(pkarkus) remove equivalent function from df_cache ++ ++ if patch.shape[-2:] == (patch_size, patch_size): ++ return patch ++ ++ top, bot, left, right = patch_sides ++ channels, height, width = map_dims ++ ++ # If we're off the map, just return zeros in the ++ # desired size of the patch. ++ if bot <= 0 or top >= height or right <= 0 or left >= width: ++ return np.zeros((channels, patch_size, patch_size)) ++ ++ pad_top, pad_bot, pad_left, pad_right = 0, 0, 0, 0 ++ if top < 0: ++ pad_top = 0 - top ++ if bot >= height: ++ pad_bot = bot - height ++ if left < 0: ++ pad_left = 0 - left ++ if right >= width: ++ pad_right = right - width ++ ++ return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)]) ++ ++ ++def load_map_patch( ++ raster_map_path: Path, ++ raster_metadata_path: Path, ++ world_x: float, ++ world_y: float, ++ desired_patch_size: int, ++ resolution: float, ++ offset_xy: Tuple[float, float], ++ agent_heading: float, ++ return_rgb: bool, ++ rot_pad_factor: float = 1.0, ++ no_map_val: float = 0.0, ++ allow_missing_map: bool = False, ++) -> Tuple[np.ndarray, np.ndarray, bool]: ++ ++ # TODO(pkarkus) remove equivalent function from df_cache ++ ++ if not raster_metadata_path.exists(): ++ if not allow_missing_map: ++ raise ValueError(f"Missing map at {raster_metadata_path}") ++ # This dataset (or location) does not have any maps, ++ # so we return an empty map. ++ patch_size: int = math.ceil((rot_pad_factor * desired_patch_size) / 2) * 2 ++ ++ return ( ++ np.full( ++ (1 if not return_rgb else 3, patch_size, patch_size), ++ fill_value=no_map_val, ++ ), ++ np.eye(3), ++ False, ++ ) ++ ++ with open(raster_metadata_path, "rb") as f: ++ map_info: RasterizedMapMetadata = dill.load(f) ++ ++ raster_from_world_tf: np.ndarray = map_info.map_from_world ++ map_coords: np.ndarray = map_info.map_from_world @ np.array( ++ [world_x, world_y, 1.0] ++ ) ++ map_x, map_y = map_coords[0].item(), map_coords[1].item() ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1.0, 0.0, -map_x], ++ [0.0, 1.0, -map_y], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ # This first size is how much of the map we ++ # need to extract to match the requested metric size (meters x meters) of ++ # the patch. ++ data_patch_size: int = math.ceil( ++ desired_patch_size * map_info.resolution / resolution ++ ) ++ ++ # Incorporating offsets. ++ if offset_xy != (0.0, 0.0): ++ # x is negative here because I am moving the map ++ # center so that the agent ends up where the user wishes ++ # (the agent is pinned from the end user's perspective). ++ map_offset: Tuple[float, float] = ( ++ -offset_xy[0] * data_patch_size // 2, ++ offset_xy[1] * data_patch_size // 2, ++ ) ++ ++ rotated_offset: np.ndarray = ( ++ arr_utils.rotation_matrix(agent_heading) @ map_offset ++ ) ++ ++ off_x = rotated_offset[0] ++ off_y = rotated_offset[1] ++ ++ map_x += off_x ++ map_y += off_y ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1.0, 0.0, -off_x], ++ [0.0, 1.0, -off_y], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ # This is the size of the patch taking into account expansion to allow for ++ # rotation to match the agent's heading. We also ensure the final size is ++ # divisible by two so that the // 2 below does not chop any information off. ++ data_with_rot_pad_size: int = math.ceil((rot_pad_factor * data_patch_size) / 2) * 2 ++ ++ disk_data = zarr.open_array(raster_map_path, mode="r") ++ ++ map_x = round(map_x) ++ map_y = round(map_y) ++ ++ # Half of the patch's side length. ++ half_extent: int = data_with_rot_pad_size // 2 ++ ++ top: int = map_y - half_extent ++ bot: int = map_y + half_extent ++ left: int = map_x - half_extent ++ right: int = map_x + half_extent ++ ++ data_patch: np.ndarray = pad_map_patch( ++ disk_data[ ++ ..., ++ max(top, 0) : min(bot, disk_data.shape[1]), ++ max(left, 0) : min(right, disk_data.shape[2]), ++ ], ++ (top, bot, left, right), ++ data_with_rot_pad_size, ++ disk_data.shape, ++ ) ++ ++ if return_rgb: ++ rgb_groups = map_info.layer_rgb_groups ++ data_patch = np.stack( ++ [ ++ np.amax(data_patch[rgb_groups[0]], axis=0), ++ np.amax(data_patch[rgb_groups[1]], axis=0), ++ np.amax(data_patch[rgb_groups[2]], axis=0), ++ ], ++ ) ++ ++ if desired_patch_size != data_patch_size: ++ scale_factor: float = desired_patch_size / data_patch_size ++ data_patch = ( ++ kornia.geometry.rescale( ++ torch.from_numpy(data_patch).unsqueeze(0), ++ scale_factor, ++ # Default align_corners value, just putting it to remove warnings ++ align_corners=False, ++ antialias=True, ++ ) ++ .squeeze(0) ++ .numpy() ++ ) ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1 / scale_factor, 0.0, 0.0], ++ [0.0, 1 / scale_factor, 0.0], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ return data_patch, raster_from_world_tf, True ++ ++ ++def batch_transform_raster_maps( ++ map_patches: List[RasterizedMapPatch], ++): ++ ++ patch_size: int = map_patches[0].crop_size ++ assert all( ++ x.crop_size == patch_size for x in map_patches ++ ) ++ ++ agents_rasters_from_world_tfs: List[np.ndarray] = [ ++ x.raster_from_world_tf for x in map_patches ++ ] ++ agents_patches: List[np.ndarray] = [x.data for x in map_patches] ++ agents_rot_angles_list: List[float] = [ ++ x.rot_angle for x in map_patches] ++ agents_res_list: List[float] = [x.resolution for x in map_patches] ++ ++ patch_data: torch.Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) ++ agents_rot_angles: torch.Tensor = torch.as_tensor( ++ np.stack(agents_rot_angles_list), dtype=torch.float ++ ) ++ agents_rasters_from_world_tf: torch.Tensor = torch.as_tensor( ++ np.stack(agents_rasters_from_world_tfs), dtype=torch.float ++ ) ++ agents_resolution: torch.Tensor = torch.as_tensor( ++ np.stack(agents_res_list), dtype=torch.float ++ ) ++ ++ patch_size_y, patch_size_x = patch_data.shape[-2:] ++ center_y: int = patch_size_y // 2 ++ center_x: int = patch_size_x // 2 ++ half_extent: int = patch_size // 2 ++ ++ if torch.count_nonzero(agents_rot_angles) == 0: ++ agents_rasters_from_world_tf = torch.bmm( ++ torch.tensor( ++ [ ++ [ ++ [1.0, 0.0, half_extent], ++ [0.0, 1.0, half_extent], ++ [0.0, 0.0, 1.0], ++ ] ++ ], ++ dtype=agents_rasters_from_world_tf.dtype, ++ device=agents_rasters_from_world_tf.device, ++ ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), ++ agents_rasters_from_world_tf, ++ ) ++ ++ rot_crop_patches = patch_data ++ else: ++ agents_rasters_from_world_tf = torch.bmm( ++ arr_utils.transform_matrices( ++ -agents_rot_angles, ++ torch.tensor([[half_extent, half_extent]]).expand( ++ (agents_rot_angles.shape[0], -1) ++ ), ++ ), ++ agents_rasters_from_world_tf, ++ ) ++ ++ # Batch rotating patches by rot_angles. ++ rot_patches: torch.Tensor = kornia.geometry.transform.rotate( ++ patch_data, torch.rad2deg(agents_rot_angles)) ++ ++ # Center cropping via slicing. ++ rot_crop_patches = rot_patches[ ++ ..., ++ center_y - half_extent : center_y + half_extent, ++ center_x - half_extent : center_x + half_extent, ++ ] ++ ++ return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + + + def decompress_values(data: np.ndarray) -> np.ndarray: +@@ -23,11 +304,11 @@ def decompress_values(data: np.ndarray) -> np.ndarray: + # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + # the origin. For subsequent points, this field stores the difference between the point's + # coordinates and the previous point's coordinates. This is for representation efficiency. +- return np.cumsum(data, axis=0, dtype=float) / MM_PER_M ++ return np.cumsum(data, axis=0, dtype=float) / COMPRESSION_SCALE + + + def compress_values(data: np.ndarray) -> np.ndarray: +- return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) ++ return (np.diff(data, axis=0, prepend=0.0) * COMPRESSION_SCALE).astype(np.int64) + + + def get_polyline_headings(points: np.ndarray) -> np.ndarray: +diff --git a/src/trajdata/utils/py_utils.py b/src/trajdata/utils/py_utils.py +new file mode 100644 +index 0000000..c302aab +--- /dev/null ++++ b/src/trajdata/utils/py_utils.py +@@ -0,0 +1,13 @@ ++import hashlib ++import json ++from typing import Dict, List, Set, Tuple, Union ++ ++ ++def hash_dict(o: Union[Dict, List, Tuple, Set]) -> str: ++ """ ++ Makes a hash from a dictionary, list, tuple or set to any level, that contains ++ only other hashable types (including any lists, tuples, sets, and ++ dictionaries). ++ """ ++ string_rep: str = json.dumps(o) ++ return hashlib.sha1(str.encode(string_rep)).hexdigest() +diff --git a/src/trajdata/utils/scene_utils.py b/src/trajdata/utils/scene_utils.py +index 804c498..1c29bc5 100644 +--- a/src/trajdata/utils/scene_utils.py ++++ b/src/trajdata/utils/scene_utils.py +@@ -60,12 +60,13 @@ def interpolate_scene_dt(scene: Scene, desired_dt: float) -> None: + + def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: + dt_ratio: float = desired_dt / scene.dt +- if not dt_ratio.is_integer(): ++ ++ if not is_integer_robust(dt_ratio): + raise ValueError( + f"Cannot subsample scene: {desired_dt} is not integer divisible by {scene.dt} for {str(scene)}" + ) + +- dt_factor: int = int(dt_ratio) ++ dt_factor: int = int(round(dt_ratio)) + + # E.g., the scene is currently at dt = 0.1s (10 Hz), + # but we want desired_dt = 0.5s (2 Hz). +@@ -86,3 +87,6 @@ def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: + scene.dt = desired_dt + # Note we do not touch scene_info.env_metadata.dt, this will serve as our + # source of the "original" data dt information. ++ ++def is_integer_robust(x): ++ return abs(x-round(x))<1e-6 +\ No newline at end of file +diff --git a/src/trajdata/utils/vis_utils.py b/src/trajdata/utils/vis_utils.py +index aab5acb..54fc036 100644 +--- a/src/trajdata/utils/vis_utils.py ++++ b/src/trajdata/utils/vis_utils.py +@@ -1,12 +1,15 @@ + from collections import defaultdict +-from typing import List, Optional, Tuple ++from typing import List, Optional, Tuple, Dict ++ + + import geopandas as gpd + import numpy as np + import pandas as pd + import seaborn as sns + from bokeh.models import ColumnDataSource, GlyphRenderer +-from bokeh.plotting import figure ++from bokeh.plotting import figure, curdoc ++import bokeh ++from bokeh.io import export_png + from shapely.geometry import LineString, Polygon + + from trajdata.data_structures.agent import AgentType +@@ -20,7 +23,12 @@ from trajdata.maps.vec_map_elements import ( + RoadArea, + RoadLane, + ) +-from trajdata.utils.arr_utils import transform_coords_2d_np ++from trajdata.utils.arr_utils import ( ++ transform_coords_2d_np, ++ batch_nd_transform_points_pt, ++ batch_nd_transform_points_np, ++) ++from PIL import Image + + + def apply_default_settings(fig: figure) -> None: +@@ -481,7 +489,7 @@ def draw_map_elems( + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + bbox: Optional[Tuple[float, float, float, float]] = None, +- **kwargs ++ **kwargs, + ) -> Tuple[GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer]: + """_summary_ + +diff --git a/tests/Drivesim_scene_generation.py b/tests/Drivesim_scene_generation.py +new file mode 100644 +index 0000000..fadd074 +--- /dev/null ++++ b/tests/Drivesim_scene_generation.py +@@ -0,0 +1,207 @@ ++from collections import namedtuple ++from typing import Any, List, Optional ++from collections import defaultdict ++import numpy as np ++import pandas as pd ++ ++from trajdata.data_structures.agent import AgentMetadata ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene ++from trajdata import AgentType ++from trajdata.data_structures.agent import FixedExtent ++from trajdata.caching.scene_cache import SceneCache ++from trajdata.caching.env_cache import EnvCache ++from trajdata.caching.df_cache import DataFrameCache ++from trajdata.dataset_specific.scene_records import DrivesimSceneRecord ++from pathlib import Path ++import dill ++from trajdata import MapAPI, VectorMap ++import tbsim.utils.lane_utils as LaneUtils ++from bokeh.plotting import figure, show, save ++import bokeh ++ ++def generate_agentmeta(initial_state,agent_names,agent_extents,T,dt,hist_types): ++ total_agent_data = list() ++ total_agent_metadata = list() ++ for x0,name,extent,htype in zip(initial_state,agent_names,agent_extents,hist_types): ++ agent_meta = AgentMetadata(name=name, ++ agent_type= AgentType.VEHICLE, ++ first_timestep = 0, ++ last_timestep = T-1, ++ extent=extent) ++ x,y,v,yaw = x0 ++ if htype=="constvel": ++ vx = v*np.cos(yaw) ++ vy = v*np.sin(yaw) ++ ax = np.zeros_like(vx) ++ ay = np.zeros_like(vy) ++ yaw = yaw*np.ones(T) ++ scene_ts = np.arange(0,T) ++ x = x+vx*(scene_ts-T+1)*dt ++ y = y+vy*(scene_ts-T+1)*dt ++ elif htype in ["brake","accelerate"]: ++ acce = -2.0 if htype=="brake" else 2.0 ++ seq = np.concatenate([np.ones(T-int(T/2))*int(T/2),np.arange(int(T/2)-1,-1,-1)]) ++ vseq = np.clip(v-acce*seq*dt,0.0,10.0) ++ vx = vseq*np.cos(yaw) ++ vy = vseq*np.sin(yaw) ++ ax = vx[1:]-vx[:-1] ++ ax = np.concatenate([ax,ax[-1:]]) ++ ay = vy[1:]-vy[:-1] ++ ay = np.concatenate([ay,ay[-1:]]) ++ yaw = yaw*np.ones(T) ++ scene_ts = np.arange(0,T) ++ x = x+(vx.cumsum()-vx.sum())*dt ++ y = y+(vy.cumsum()-vy.sum())*dt ++ ++ track_id = [name]*T ++ z = np.zeros(T) ++ ++ ++ pd_frame = pd.DataFrame({"agent_id":track_id,"scene_ts":scene_ts,"x":x,"y":y,"z":z,"heading":yaw,"vx":vx,"vy":vy,"ax":ax,"ay":ay}) ++ total_agent_data.append(pd_frame) ++ total_agent_metadata.append(agent_meta) ++ total_agent_data = pd.concat(total_agent_data).set_index(["agent_id", "scene_ts"]) ++ return total_agent_data,total_agent_metadata ++ ++ ++ ++def generate_drivesim_scene(): ++ ++ repeat = 10 ++ ++ dt = 0.1 ++ data_dir = "" ++ T = 20 ++ cache_path = Path("/home/yuxiaoc/.unified_data_cache") ++ env_metadata = EnvMetadata("drivesim", ++ data_dir, ++ dt, ++ parts=[("train",),("main",)], ++ scene_split_map=defaultdict(lambda: "train"), ++ map_locations=("main",)) ++ env_cache = EnvCache(cache_path) ++ scene_records = list() ++ ++ agent_initial_state = [(np.array([-565,-1001,3.0,0]),"accelerate"), ++ (np.array([-573,-1001,3.0,0]),"accelerate"), ++ (np.array([-573,-1005,3.0,0.0]),"accelerate"), ++ (np.array([-530.0,-976,0.0,-0.95*np.pi/2]),"constvel"), ++ (np.array([-526.4,-976,0.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-507.8,-1027,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-507.8,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-504,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-500.5,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-480.5,-992,0.0,np.pi]),"brake"), ++ (np.array([-473,-992,0.0,np.pi]),"brake"), ++ (np.array([-489,-1049,8.0,np.pi/2-0.1]),"constvel"), ++ (np.array([-488.4,-983,5.0,np.pi*0.75]),"accelerate"), ++ (np.array([-524.2,-964.6,2.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-523.3,-976,0.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-519.5,-975,0.0,-0.92*np.pi/2]),"brake"), ++ (np.array([-563.4,-1011,8.0,-0.18*np.pi/2]),"constvel"), ++ ] ++ ++ noise_spec = [1,1,1,2,2,3,3,3] ++ for r in range(repeat): ++ group1_xn =np.random.randn()*5 ++ group1_yn = np.random.randn()*0.5 ++ group1_vn = np.random.randn()*0.5 ++ group1_psin = np.random.randn()*0.01 ++ group2_xn = np.random.randn()*0.2 ++ group2_yn = np.random.randn()*0.2 ++ group2_vn = 0 ++ group2_psin = np.random.randn()*0.01 ++ group3_xn = np.random.randn()*0.2 ++ group3_yn = np.random.randn()*0.2 ++ group3_vn = 0 ++ group3_psin = np.random.randn()*0.01 ++ group1_n = np.array([group1_xn,group1_yn,group1_vn,group1_psin]) ++ group2_n = np.array([group2_xn,group2_yn,group2_vn,group2_psin]) ++ group3_n = np.array([group3_xn,group3_yn,group3_vn,group3_psin]) ++ init_state=np.stack([x0 for x0,_ in agent_initial_state]) ++ hist_types = [htype for _,htype in agent_initial_state] ++ for i in range(len(init_state)): ++ if i None: ++ super().__init__(methodName) ++ ++ data_source = "nusc_mini" ++ history_sec = 2.0 ++ prediction_sec = 6.0 ++ ++ attention_radius = defaultdict( ++ lambda: 20.0 ++ ) # Default range is 20m unless otherwise specified. ++ attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 ++ attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 ++ attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 ++ attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 ++ ++ self._map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} ++ ++ self._scene_dataset = UnifiedDataset( ++ centric="scene", ++ desired_data=[data_source], ++ history_sec=(history_sec, history_sec), ++ future_sec=(prediction_sec, prediction_sec), ++ agent_interaction_distances=attention_radius, ++ incl_robot_future=False, ++ incl_raster_map=True, ++ raster_map_params=self._map_params, ++ only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ++ no_types=[AgentType.UNKNOWN], ++ num_workers=0, ++ standardize_data=True, ++ data_dirs={ ++ "nusc_mini": "~/datasets/nuScenes", ++ }, ++ ) ++ ++ self._scene_dataloader = DataLoader( ++ self._scene_dataset, ++ batch_size=4, ++ shuffle=False, ++ collate_fn=self._scene_dataset.get_collate_fn(), ++ num_workers=0, ++ ) ++ ++ def _assert_allclose_with_nans(self, tensor1, tensor2, atol=1e-8): ++ """ ++ asserts that the two tensors have nans in the same locations, and the non-nan ++ elements all are close. ++ """ ++ # Check nans are in the same place ++ self.assertFalse( ++ torch.any( # True if there's any mismatch ++ torch.logical_xor( # True where either tensor1 or tensor 2 has nans, but not both (mismatch) ++ torch.isnan(tensor1), # True where tensor1 has nans ++ torch.isnan(tensor2), # True where tensor2 has nans ++ ) ++ ), ++ msg="Nans occur in different places.", ++ ) ++ valid_mask = torch.logical_not(torch.isnan(tensor1)) ++ self.assertTrue( ++ torch.allclose(tensor1[valid_mask], tensor2[valid_mask], atol=atol), ++ msg="Non-nan values don't match.", ++ ) ++ ++ def test_map_transform_scenebatch(self): ++ scene_batch: SceneBatch ++ for i, scene_batch in enumerate(self._scene_dataloader): ++ ++ # Make the tf double for more accurate transform. ++ scene_batch.centered_world_from_agent_tf = scene_batch.centered_world_from_agent_tf.double() ++ ++ maps, maps_resolution, raster_from_world_tf = get_raster_maps_for_scene_batch( ++ scene_batch, self._scene_dataset.cache_path, "nusc_mini", self._map_params) ++ ++ self._assert_allclose_with_nans(scene_batch.rasters_from_world_tf, raster_from_world_tf, atol=1e-2) ++ self._assert_allclose_with_nans(scene_batch.maps_resolution, maps_resolution) ++ self._assert_allclose_with_nans(scene_batch.maps, maps, atol=1e-4) ++ ++ if i > 50: ++ break ++ ++if __name__ == "__main__": ++ unittest.main(catchbreak=False) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f55ac42 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +matplotlib==3.3.4 +seaborn==0.11.1 +# numpy +# scipy +# scikit-learn==0.24.1 +# torch==1.10.2 +pyquaternion==0.9.9 +pytest==6.2.2 +orjson==3.5.1 +ncls==0.0.57 +dill==0.3.5.1 +pathos==0.2.9 +tqdm>=4.53.0 +notebook==6.2.0 +nuscenes-devkit==1.1.5 +pykalman==0.9.5 +# You will likely need to source deactivate and activate +# again to get the correct tensorboard version used. +gym>=0.18.3 +stable-baselines3==1.1.0a11 # pip install git+https://github.com/DLR-RM/stable-baselines3 +requests==2.28.2 +pympler==1.0.1 +moviepy==1.0.3 +wandb +imageio==2.9.0 +ipdb +l5kit==1.5.0 +networkx + +community +pytorch-lightning==2.0.5 +# You will likely need to uninstall opencv-python and install an earlier version instead because of a compatibility issue +# https://github.com/opencv/opencv-python/issues/591 +# pip uninstall opencv-python -y +# pip install "opencv-python-headless<4.3" +hydra-core==1.3.2 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..46bd280 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup, find_packages + +# read the contents of your README file +from os import path + +this_directory = path.abspath(path.dirname(__file__)) + + +# remove images from README + +setup( + name="diffstack", + packages=[ + package for package in find_packages() if package.startswith("diffstack") + ], + install_requires=[ + "wandb", + "pytorch-lightning", + ], + eager_resources=["*"], + include_package_data=True, + python_requires=">=3", + description="diffstack", + author="NVIDIA AV Research", + author_email="", + version="0.0.1", + long_description="", + long_description_content_type="text/markdown", +) From c4f87093cc0219d2342272613f03bbc2c5afcd52 Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Wed, 29 Nov 2023 23:06:26 -0800 Subject: [PATCH 02/10] typo fix --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 154510b..337b48d 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,6 @@ Then add Pplan ``` git clone --recurse-submodules git@github.com:NVlabs/spline-planner.git - ``` You can also sync submodules later using @@ -70,8 +69,10 @@ These additional steps might be necessary # need to reinstall pathos, gets replaced by multiprocessing install pip uninstall pathos -y pip install pathos==0.2.9 +``` + -# Sometimes you need to reinstall matplotlib with the correct version +Sometimes you need to reinstall matplotlib with the correct version ``` pip install matplotlib==3.3.4 From 9bcf0508647ac2ba061eea3643f630d3ca3fca5e Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Wed, 29 Nov 2023 23:09:26 -0800 Subject: [PATCH 03/10] fix typo --- README.md | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 337b48d..324458d 100644 --- a/README.md +++ b/README.md @@ -64,31 +64,51 @@ pip install -e ./trajdata pip install -e ./spline-planner ``` + + These additional steps might be necessary ``` # need to reinstall pathos, gets replaced by multiprocessing install pip uninstall pathos -y pip install pathos==0.2.9 -``` - -Sometimes you need to reinstall matplotlib with the correct version - -``` -pip install matplotlib==3.3.4 -``` +# Gpu affinity for cpu-gpu assignments on NGC (optional) +pip install git+https://gitlab-master.nvidia.com/dl/gwe/gpu_affinity # On Mac sometimes we need to reinstall torch conda install pytorch torchvision torchaudio -c pytorch +# The default requirements installs jax for cpu only. To enable jax with GPU, see https://github.com/google/jax#installation + +# networkx package is not well aligned between py3.8 and py3.9, if you encounter an error for unknown module of gcd in fraction, manually modify that line of code for networkx. It should be located in the site-package/networkx/algorithms/dag.py:23 +# from fractions import gcd +from math import gcd +# The version of numpy might be messed by merging diffstack + mm3d, try restore numpy version for mm3d if the test script does not work after install diffstack requirements +pip uninstall numpy +pip install numpy==1.23.5 + +# The version of bokeh might be messed if you see errors for 'module not found' when using nuplan, update bokeh version to bokeh==2.4.3 +pip install bokeh==2.4.3 + +pip uninstall pygeos + +# To parse CARLA OpenDrive maps manually install extra trajdata dependencties (can be removed once trajdata is updated) +pip install intervaltree bokeh==2.4.3 geopandas selenium +pip install -e ./trajdata/src/trajdata/dataset_specific/opendrive/custom_imap # Fix opencv compatibility issue https://github.com/opencv/opencv-python/issues/591 pip uninstall opencv-python opencv-python-headless -y pip install "opencv-python-headless==4.2.0.34" # pip install "opencv-python-headless==4.7.0.72" # for python 3.9 + +# Sometimes you need to reinstall matplotlib with the correct version + +pip install matplotlib==3.3.4 + ``` + ### Data CTT uses [trajdata](https://github.com/NVlabs/trajdata) as the dataloader, technically, you can train with any dataset supported by trajdata. Considering the vectorized map support, we have tested CTT with WOMD, nuScenes, and nuPlan. From 497113cefe1c20712718641abd547869e950f2ee Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Thu, 30 Nov 2023 11:25:15 -0800 Subject: [PATCH 04/10] rename config json --- ...AFPredStack.json => AFPredStack_nusc.json} | 0 ...TPredStack.json => CTTPredStack_nusc.json} | 0 config/templates/SceneFormerPredStack.json | 358 ------------------ .../SceneFormerPredStack_nuplan.json | 358 ------------------ 4 files changed, 716 deletions(-) rename config/templates/{AFPredStack.json => AFPredStack_nusc.json} (100%) rename config/templates/{CTTPredStack.json => CTTPredStack_nusc.json} (100%) delete mode 100644 config/templates/SceneFormerPredStack.json delete mode 100644 config/templates/SceneFormerPredStack_nuplan.json diff --git a/config/templates/AFPredStack.json b/config/templates/AFPredStack_nusc.json similarity index 100% rename from config/templates/AFPredStack.json rename to config/templates/AFPredStack_nusc.json diff --git a/config/templates/CTTPredStack.json b/config/templates/CTTPredStack_nusc.json similarity index 100% rename from config/templates/CTTPredStack.json rename to config/templates/CTTPredStack_nusc.json diff --git a/config/templates/SceneFormerPredStack.json b/config/templates/SceneFormerPredStack.json deleted file mode 100644 index eade21a..0000000 --- a/config/templates/SceneFormerPredStack.json +++ /dev/null @@ -1,358 +0,0 @@ -{ - "registered_name": "SceneFormerPredStack", - "train": { - "logging": { - "terminal_output_to_txt": true, - "log_tb": false, - "log_wandb": true, - "wandb_project_name": "diffstack", - "log_every_n_steps": 10, - "flush_every_n_steps": 100 - }, - "save": { - "enabled": true, - "every_n_steps": 500, - "best_k": 10, - "save_best_rollout": false, - "save_best_validation": true - }, - "rollout": { - "save_video": true, - "enabled": false, - "every_n_steps": 5000, - "warm_start_n_steps": 1 - }, - "training": { - "batch_size": 24, - "num_steps": 1000000, - "num_data_workers": 64 - }, - "validation": { - "enabled": true, - "batch_size": 24, - "num_data_workers": 64, - "every_n_steps": 400, - "num_steps_per_epoch": 20 - }, - "parallel_strategy": "ddp", - "rebuild_cache": false, - "on_ngc": false, - "trajdata_source_train": "train", - "trajdata_source_valid": "val", - "trajdata_source_root": "nusc_trainval", - "trajdata_val_source_root": null, - "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", - "datamodule_class": "UnifiedDataModule", - "ego_only": true, - "amp": true, - "auto_batch_size": false, - "max_batch_size": 36, - "gradient_clip_val": 0.5 - }, - "env": { - "name": "nusc_trainval", - "rasterizer": { - "raster_size": 224, - "pixel_size": 0.5, - "ego_center": [ - -0.75, - 0.0 - ] - }, - "data_generation_params": { - "other_agents_num": 11, - "max_agents_distance": 40, - "yaw_correction_speed": 1.0 - }, - "simulation": { - "distance_th_close": 30, - "num_simulation_steps": 50, - "start_frame_index": 0, - "vectorize_lane": "ego" - }, - "incl_neighbor_map": false, - "incl_vector_map": true, - "incl_raster_map": false, - "calc_lane_graph": true, - "max_num_lanes": 32, - "num_lane_pts": 32, - "remove_single_successor": true - }, - "stack": { - "predictor": { - "name": "sceneformer", - "step_time": 0.25, - "history_num_frames": 6, - "future_num_frames": 12, - "n_embd": 128, - "n_head": 4, - "PE_mode": "PE", - "use_rpe_net": false, - "enc_nblock": 2, - "dec_nblock": 2, - "edge_dim": { - "a2a": 14, - "a2l": 12, - "l2a": 12, - "l2l": 16 - }, - "a2l_edge_type": "proj", - "a2l_n_embd": 64, - "attn_ntype": { - "a2a": 2, - "a2l": 1, - "l2l": 2 - }, - "lane_GNN_num_layers": 4, - "homotopy_GNN_num_layers": 4, - "closed_loop": false, - "CL_Tf_mode": 6, - "CL_step": 2, - "encoder": { - "attn_pdrop": 0.05, - "resid_pdrop": 0.05, - "pooling": "attn", - "edge_scale": 20, - "edge_clip": [ - -4, - 4 - ], - "mode_embed_dim": 64, - "jm_GNN_nblock": 2, - "num_joint_samples": 30, - "num_joint_factors": 7, - "null_lane_mode": true - }, - "decoder": { - "arch": "mlp", - "lstm_hidden_size": 128, - "mlp_hidden_dims": [ - 128, - 256 - ], - "traj_dim": 4, - "num_layers": 2, - "dyn": { - "vehicle": "unicycle", - "pedestrian": "DI_unicycle" - }, - "attn_pdrop": 0.05, - "resid_pdrop": 0.05, - "pooling": "attn", - "decode_num_modes": 3, - "AR_step_size": 1, - "GNN_enabled": false, - "AR_update_mode": "step", - "dec_rounds": 5 - }, - "num_lane_pts": 32, - "hist_lane_relation": "LaneRelation", - "fut_lane_relation": "SimpleLaneRelation", - "classify_a2l_4all_lanes": false, - "max_joint_cardinality": 5, - "loss_weights": { - "marginal_lm_loss": 5.0, - "marginal_homo_loss": 5.0, - "joint_prob_loss": 5.0, - "xy_loss": 2.0, - "heading_loss": 1.0, - "l2_reg": 0.0001, - "lm_consistency_loss": 5.0, - "homotopy_consistency_loss": 5.0, - "coll_loss": 2.0, - "acce_reg_loss": 0.05, - "steering_reg_loss": 0.2, - "input_violation_loss": 20.0, - "jerk_loss": 0.1 - }, - "loss": { - "lm_margin_offset": 0.2 - }, - "weighted_consistency_loss": false, - "LR_sample_hack": true, - "scene_centric": true, - "optim_params": { - "policy": { - "learning_rate": { - "initial": 0.0001, - "decay_factor": 0.05, - "epoch_schedule": [] - }, - "regularization": { - "L2": 0.0 - } - } - } - }, - "name": "pred" - }, - "eval": { - "name": null, - "env": "nusc", - "dataset_path": null, - "eval_class": "", - "seed": 0, - "num_scenes_per_batch": 4, - "num_scenes_to_evaluate": 100, - "num_episode_repeats": 3, - "start_frame_index_each_episode": null, - "seed_each_episode": null, - "ego_only": false, - "agent_eval_class": null, - "ckpt_root_dir": "checkpoints/", - "experience_hdf5_path": null, - "results_dir": "results/", - "ckpt": { - "policy": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "planner": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "predictor": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "cvae_metric": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "occupancy_metric": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - } - }, - "policy": { - "mask_drivable": true, - "num_plan_samples": 50, - "num_action_samples": 10, - "pos_to_yaw": true, - "yaw_correction_speed": 1.0, - "diversification_clearance": null, - "sample": false, - "cost_weights": { - "collision_weight": 10.0, - "lane_weight": 1.0, - "likelihood_weight": 0.0, - "progress_weight": 0.0 - } - }, - "metrics": { - "compute_analytical_metrics": true, - "compute_learned_metrics": false - }, - "perturb": { - "enabled": false, - "OU": { - "theta": 0.8, - "sigma": [ - 0.0, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 4.0 - ], - "scale": [ - 1.0, - 1.0, - 0.2 - ] - } - }, - "rolling_perturb": { - "enabled": false, - "OU": { - "theta": 0.8, - "sigma": 0.5, - "scale": [ - 1.0, - 1.0, - 0.2 - ] - } - }, - "occupancy": { - "rolling": true, - "rolling_horizon": [ - 5, - 10, - 20 - ] - }, - "cvae": { - "rolling": true, - "rolling_horizon": [ - 5, - 10, - 20 - ] - }, - "nusc": { - "eval_scenes": [ - 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90 - ], - "n_step_action": 5, - "num_simulation_steps": 200, - "skip_first_n": 0 - }, - "l5kit": { - "eval_scenes": [ - 9058, - 5232, - 14153, - 8173, - 10314, - 7027, - 9812, - 1090, - 9453, - 978, - 10263, - 874, - 5563, - 9613, - 261, - 2826, - 2175, - 9977, - 6423, - 1069 - ], - "n_step_action": 5, - "num_simulation_steps": 200, - "skip_first_n": 1, - "skimp_rollout": false - }, - "adjustment": { - "random_init_plan": true, - "remove_existing_neighbors": false, - "initial_num_neighbors": 4, - "num_frame_per_new_agent": 20 - } - }, - "stack_type": "pred", - "name": "test", - "root_dir": "predictor_sceneformer_trained_models/", - "seed": 1, - "devices": { - "num_gpus": 1 - } -} \ No newline at end of file diff --git a/config/templates/SceneFormerPredStack_nuplan.json b/config/templates/SceneFormerPredStack_nuplan.json deleted file mode 100644 index f1c3337..0000000 --- a/config/templates/SceneFormerPredStack_nuplan.json +++ /dev/null @@ -1,358 +0,0 @@ -{ - "registered_name": "SceneFormerPredStack", - "train": { - "logging": { - "terminal_output_to_txt": true, - "log_tb": false, - "log_wandb": true, - "wandb_project_name": "diffstack", - "log_every_n_steps": 10, - "flush_every_n_steps": 100 - }, - "save": { - "enabled": true, - "every_n_steps": 3000, - "best_k": 10, - "save_best_rollout": false, - "save_best_validation": true - }, - "rollout": { - "save_video": true, - "enabled": false, - "every_n_steps": 5000, - "warm_start_n_steps": 1 - }, - "training": { - "batch_size": 20, - "num_steps": 1000000, - "num_data_workers": 64 - }, - "validation": { - "enabled": true, - "batch_size": 20, - "num_data_workers": 64, - "every_n_steps": 1500, - "num_steps_per_epoch": 200 - }, - "parallel_strategy": "ddp", - "rebuild_cache": false, - "on_ngc": false, - "trajdata_source_train": "train", - "trajdata_source_valid": "val", - "trajdata_source_root": "nuplan_train", - "trajdata_val_source_root": "nuplan_val", - "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", - "datamodule_class": "UnifiedDataModule", - "ego_only": true, - "amp": true, - "auto_batch_size": false, - "max_batch_size": 36, - "gradient_clip_val": 0.5 - }, - "env": { - "name": "nusc_trainval", - "rasterizer": { - "raster_size": 224, - "pixel_size": 0.5, - "ego_center": [ - -0.75, - 0.0 - ] - }, - "data_generation_params": { - "other_agents_num": 11, - "max_agents_distance": 40, - "yaw_correction_speed": 1.0 - }, - "simulation": { - "distance_th_close": 30, - "num_simulation_steps": 50, - "start_frame_index": 0, - "vectorize_lane": "ego" - }, - "incl_neighbor_map": false, - "incl_vector_map": true, - "incl_raster_map": false, - "calc_lane_graph": true, - "max_num_lanes": 32, - "num_lane_pts": 32, - "remove_single_successor": true - }, - "stack": { - "predictor": { - "name": "sceneformer", - "step_time": 0.25, - "history_num_frames": 6, - "future_num_frames": 12, - "n_embd": 128, - "n_head": 4, - "PE_mode": "PE", - "use_rpe_net": false, - "enc_nblock": 2, - "dec_nblock": 2, - "edge_dim": { - "a2a": 14, - "a2l": 12, - "l2a": 12, - "l2l": 16 - }, - "a2l_edge_type": "proj", - "a2l_n_embd": 64, - "attn_ntype": { - "a2a": 2, - "a2l": 1, - "l2l": 2 - }, - "lane_GNN_num_layers": 4, - "homotopy_GNN_num_layers": 4, - "closed_loop": false, - "CL_Tf_mode": 6, - "CL_step": 2, - "encoder": { - "attn_pdrop": 0.05, - "resid_pdrop": 0.05, - "pooling": "attn", - "edge_scale": 20, - "edge_clip": [ - -4, - 4 - ], - "mode_embed_dim": 64, - "jm_GNN_nblock": 2, - "num_joint_samples": 30, - "num_joint_factors": 7, - "null_lane_mode": true - }, - "decoder": { - "arch": "mlp", - "lstm_hidden_size": 128, - "mlp_hidden_dims": [ - 128, - 256 - ], - "traj_dim": 4, - "num_layers": 2, - "dyn": { - "vehicle": "unicycle", - "pedestrian": "DI_unicycle" - }, - "attn_pdrop": 0.05, - "resid_pdrop": 0.05, - "pooling": "attn", - "decode_num_modes": 3, - "AR_step_size": 1, - "GNN_enabled": false, - "AR_update_mode": "step", - "dec_rounds": 5 - }, - "num_lane_pts": 32, - "hist_lane_relation": "LaneRelation", - "fut_lane_relation": "SimpleLaneRelation", - "classify_a2l_4all_lanes": false, - "max_joint_cardinality": 5, - "loss_weights": { - "marginal_lm_loss": 5.0, - "marginal_homo_loss": 5.0, - "joint_prob_loss": 5.0, - "xy_loss": 4.0, - "heading_loss": 1.0, - "l2_reg": 0.0001, - "lm_consistency_loss": 5.0, - "homotopy_consistency_loss": 5.0, - "coll_loss": 2.0, - "acce_reg_loss": 0.05, - "steering_reg_loss": 0.2, - "input_violation_loss": 20.0, - "jerk_loss": 0.1 - }, - "loss": { - "lm_margin_offset": 0.2 - }, - "weighted_consistency_loss": false, - "LR_sample_hack": true, - "scene_centric": true, - "optim_params": { - "policy": { - "learning_rate": { - "initial": 0.0001, - "decay_factor": 0.05, - "epoch_schedule": [] - }, - "regularization": { - "L2": 0.0 - } - } - } - }, - "name": "pred" - }, - "eval": { - "name": null, - "env": "nusc", - "dataset_path": null, - "eval_class": "", - "seed": 0, - "num_scenes_per_batch": 4, - "num_scenes_to_evaluate": 100, - "num_episode_repeats": 3, - "start_frame_index_each_episode": null, - "seed_each_episode": null, - "ego_only": false, - "agent_eval_class": null, - "ckpt_root_dir": "checkpoints/", - "experience_hdf5_path": null, - "results_dir": "results/", - "ckpt": { - "policy": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "planner": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "predictor": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "cvae_metric": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "occupancy_metric": { - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - } - }, - "policy": { - "mask_drivable": true, - "num_plan_samples": 50, - "num_action_samples": 10, - "pos_to_yaw": true, - "yaw_correction_speed": 1.0, - "diversification_clearance": null, - "sample": false, - "cost_weights": { - "collision_weight": 10.0, - "lane_weight": 1.0, - "likelihood_weight": 0.0, - "progress_weight": 0.0 - } - }, - "metrics": { - "compute_analytical_metrics": true, - "compute_learned_metrics": false - }, - "perturb": { - "enabled": false, - "OU": { - "theta": 0.8, - "sigma": [ - 0.0, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 4.0 - ], - "scale": [ - 1.0, - 1.0, - 0.2 - ] - } - }, - "rolling_perturb": { - "enabled": false, - "OU": { - "theta": 0.8, - "sigma": 0.5, - "scale": [ - 1.0, - 1.0, - 0.2 - ] - } - }, - "occupancy": { - "rolling": true, - "rolling_horizon": [ - 5, - 10, - 20 - ] - }, - "cvae": { - "rolling": true, - "rolling_horizon": [ - 5, - 10, - 20 - ] - }, - "nusc": { - "eval_scenes": [ - 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90 - ], - "n_step_action": 5, - "num_simulation_steps": 200, - "skip_first_n": 0 - }, - "l5kit": { - "eval_scenes": [ - 9058, - 5232, - 14153, - 8173, - 10314, - 7027, - 9812, - 1090, - 9453, - 978, - 10263, - 874, - 5563, - 9613, - 261, - 2826, - 2175, - 9977, - 6423, - 1069 - ], - "n_step_action": 5, - "num_simulation_steps": 200, - "skip_first_n": 1, - "skimp_rollout": false - }, - "adjustment": { - "random_init_plan": true, - "remove_existing_neighbors": false, - "initial_num_neighbors": 4, - "num_frame_per_new_agent": 20 - } - }, - "stack_type": "pred", - "name": "test", - "root_dir": "predictor_sceneformer_trained_models/", - "seed": 1, - "devices": { - "num_gpus": 4 - } -} \ No newline at end of file From cfb693001d94004891792eabe56b62ca238dc827 Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Thu, 30 Nov 2023 11:27:24 -0800 Subject: [PATCH 05/10] fix config name --- README.md | 2 ++ config/templates/{AFPredStack_nusc.json => AFPredStack.json} | 0 config/templates/{CTTPredStack_nusc.json => CTTPredStack.json} | 0 3 files changed, 2 insertions(+) rename config/templates/{AFPredStack_nusc.json => AFPredStack.json} (100%) rename config/templates/{CTTPredStack_nusc.json => CTTPredStack.json} (100%) diff --git a/README.md b/README.md index 324458d..1ea4468 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,8 @@ python diffstack/scripts/generate_config_templates.py ## Training and eval +The following examples use nuScenes trainval as dataset, you'll need to prepare the nuScenes dataset following instructions in [trajdata](https://github.com/NVlabs/trajdata). + Training script: ``` diff --git a/config/templates/AFPredStack_nusc.json b/config/templates/AFPredStack.json similarity index 100% rename from config/templates/AFPredStack_nusc.json rename to config/templates/AFPredStack.json diff --git a/config/templates/CTTPredStack_nusc.json b/config/templates/CTTPredStack.json similarity index 100% rename from config/templates/CTTPredStack_nusc.json rename to config/templates/CTTPredStack.json From f84dd23bd4f927ab348ff4a35c2f77da087d5564 Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Thu, 30 Nov 2023 23:39:39 -0800 Subject: [PATCH 06/10] update readme --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ea4468..c0a9660 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Impements Categorical Traffic Transformer in the environment of diffstack. -Paper [pdf](link) +Paper [pdf](https://arxiv.org/abs/2311.18307) ## Setup @@ -150,4 +150,10 @@ python diffstack/scripts/train_pl.py Training and eval example commands are also included in the `.vscode/launch.json` file. +## Trained models +| Training dataset | dt | Th | Tf | config | checkpoint | +|------------------|-------|------|----|--------|------------| +| nuScenes | 0.25s | 1.5s | 3s | [config](https://drive.google.com/file/d/1fnPX0o2qPVGszFxbYX_LDSYJ221IUFk7/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1KvTdJQIEtk50cwiUzMFdl-ZtxpKg52kK/view?usp=drive_link) | +| nuPlan | 0.25s | 1.5s | 3s |[config](https://drive.google.com/file/d/1huNKKlTeT_i3oMOgPL6L1iUT2hKouKtt/view?usp=drive_link) |[ckpt](https://drive.google.com/file/d/1w66sf6sTaoLI-Rl6MFHpuegu0y-i5R0y/view?usp=drive_link) | +| WOMD | 0.2s | 1s | 8s | [config](https://drive.google.com/file/d/1QgsHm3UhY74245YbhsQ4GlTyxyOBpE5y/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1qClV16V8jlSMMuPoAavFQeF7qJlb71GV/view?usp=drive_link) | From 227f472ce62f80d7ded4091efea29d72c5606ebb Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Thu, 30 Nov 2023 23:41:01 -0800 Subject: [PATCH 07/10] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c0a9660..4910c3c 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,7 @@ Training and eval example commands are also included in the `.vscode/launch.json ## Trained models -| Training dataset | dt | Th | Tf | config | checkpoint | +| Training dataset | Step time | History horizon | Future horizon | config | checkpoint | |------------------|-------|------|----|--------|------------| | nuScenes | 0.25s | 1.5s | 3s | [config](https://drive.google.com/file/d/1fnPX0o2qPVGszFxbYX_LDSYJ221IUFk7/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1KvTdJQIEtk50cwiUzMFdl-ZtxpKg52kK/view?usp=drive_link) | | nuPlan | 0.25s | 1.5s | 3s |[config](https://drive.google.com/file/d/1huNKKlTeT_i3oMOgPL6L1iUT2hKouKtt/view?usp=drive_link) |[ckpt](https://drive.google.com/file/d/1w66sf6sTaoLI-Rl6MFHpuegu0y-i5R0y/view?usp=drive_link) | From a6b519469823b6643885570d8451adf386b9822e Mon Sep 17 00:00:00 2001 From: chenyx09 Date: Thu, 30 Nov 2023 23:43:10 -0800 Subject: [PATCH 08/10] fix readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 4910c3c..5127fdc 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,8 @@ Training and eval example commands are also included in the `.vscode/launch.json ## Trained models + + | Training dataset | Step time | History horizon | Future horizon | config | checkpoint | |------------------|-------|------|----|--------|------------| | nuScenes | 0.25s | 1.5s | 3s | [config](https://drive.google.com/file/d/1fnPX0o2qPVGszFxbYX_LDSYJ221IUFk7/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1KvTdJQIEtk50cwiUzMFdl-ZtxpKg52kK/view?usp=drive_link) | From c5e886ab27004ddada0403ab2c3ef54250706f95 Mon Sep 17 00:00:00 2001 From: yuxiaoc Date: Fri, 1 Dec 2023 19:34:58 -0800 Subject: [PATCH 09/10] remove AF, update readme --- .vscode/launch.json | 37 - LICENSE.txt | 64 + README.md | 41 +- config/templates/AFPredStack.json | 260 -- diffstack/configs/algo_config.py | 97 - diffstack/configs/registry.py | 10 - diffstack/models/agentformer.py | 3324 ----------------- diffstack/models/agentformer_lib.py | 1044 ------ diffstack/modules/predictors/factory.py | 12 +- .../modules/predictors/tbsim_predictors.py | 477 --- 10 files changed, 82 insertions(+), 5284 deletions(-) create mode 100644 LICENSE.txt delete mode 100644 config/templates/AFPredStack.json delete mode 100644 diffstack/models/agentformer.py delete mode 100644 diffstack/models/agentformer_lib.py delete mode 100644 diffstack/modules/predictors/tbsim_predictors.py diff --git a/.vscode/launch.json b/.vscode/launch.json index b96296a..78846e8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,23 +16,6 @@ "--dataset_path=", ] }, - { - "name": "PL train agentformer", - "type": "python", - "request": "launch", - "program": "diffstack/scripts/train_pl.py", - "console": "integratedTerminal", - "justMyCode": true, - "env": { - "PYTHONPATH": "${workspaceFolder}${pathSeparator}${env:PYTHONPATH}", - }, - "args": [ - "--config_file=${workspaceFolder}/config/templates/AFPredStack.json", - "--remove_exp_dir", - // "--debug", - "--dataset_path=", - ] - }, { "name": "PL eval CTT", "type": "python", @@ -53,25 +36,5 @@ "--dataset_path=", ] }, - { - "name": "PL eval agentformer", - "type": "python", - "request": "launch", - "program": "diffstack/scripts/train_pl.py", - "console": "integratedTerminal", - "justMyCode": true, - "args": [ - "--config_file=", - "--ckpt_path=", - "--test_data=", - "--test_data_root=", - "--evaluate", - "--log_image_frequency=10", - "--eval_output_dir=", - "--test_batch_size=16", - // "--log_all_image", - "--dataset_path=", - ] - } ] } diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..c772d05 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,64 @@ +NVIDIA Source Code License for DiffStack + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. + +“Software” means the original work of authorship made available under this License. + +“Work” means the Software and any additions to or derivative works of the Software that are made available under +this License. + +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under +U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include +works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are “made available” under this License by including in or with the Work either +(a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, +worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly +display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you +include a complete copy of this License with your distribution, and (c) you retain without modification any +copyright, patent, trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and +distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use +limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works +that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution +requirements in Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use +non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative +works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, +cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then +your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, +or trademarks, except as necessary to reproduce the notices described in this License. + +3.6 Termination. If you violate any term of this License, then your rights under this License (including the +grant in Section 2.1) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU +BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING +NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR +INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR +DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN +ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. \ No newline at end of file diff --git a/README.md b/README.md index 5127fdc..c9bad7e 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ For trajdata, we need to use branch `vectorize`, there are two options: ``` git clone --recurse-submodules --branch main git@github.com:NVlabs/trajdata.git; cd trajdata; +git fetch origin +git reset --hard 748b8b1 git apply ../patches/trajdata_vectorize.patch cd .. ``` @@ -72,30 +74,6 @@ These additional steps might be necessary pip uninstall pathos -y pip install pathos==0.2.9 -# Gpu affinity for cpu-gpu assignments on NGC (optional) -pip install git+https://gitlab-master.nvidia.com/dl/gwe/gpu_affinity - -# On Mac sometimes we need to reinstall torch -conda install pytorch torchvision torchaudio -c pytorch - -# The default requirements installs jax for cpu only. To enable jax with GPU, see https://github.com/google/jax#installation - -# networkx package is not well aligned between py3.8 and py3.9, if you encounter an error for unknown module of gcd in fraction, manually modify that line of code for networkx. It should be located in the site-package/networkx/algorithms/dag.py:23 -# from fractions import gcd -from math import gcd - -# The version of numpy might be messed by merging diffstack + mm3d, try restore numpy version for mm3d if the test script does not work after install diffstack requirements -pip uninstall numpy -pip install numpy==1.23.5 - -# The version of bokeh might be messed if you see errors for 'module not found' when using nuplan, update bokeh version to bokeh==2.4.3 -pip install bokeh==2.4.3 - -pip uninstall pygeos - -# To parse CARLA OpenDrive maps manually install extra trajdata dependencties (can be removed once trajdata is updated) -pip install intervaltree bokeh==2.4.3 geopandas selenium -pip install -e ./trajdata/src/trajdata/dataset_specific/opendrive/custom_imap # Fix opencv compatibility issue https://github.com/opencv/opencv-python/issues/591 pip uninstall opencv-python opencv-python-headless -y @@ -108,6 +86,21 @@ pip install matplotlib==3.3.4 ``` +### Key files and code structure + +Diffstack uses a similar config system as [TBSIM](https://github.com/NVlabs/traffic-behavior-simulation), where the config templates are first defined in python inside the [diffstack/configs](/diffstack/configs/) folder. We separate the configs for [data](/diffstack/configs/trajdata_config.py), [training](/diffstack/configs/base.py), and [models](/diffstack/configs/algo_config.py). + +The training and evaluation process takes in a JSON file as config, and one can call the [generate_config_templates.py](/diffstack/scripts/generate_config_templates.py) to generate all the template JSON configs, stored in [config/templates](/config/templates/) folder, by taking the default values from the python config files. + +The models are separetely defined in the [models](/diffstack/models/) folder and [modules](/diffstack/modules/) folder where the former defines the model architecture, the latter wraps the torch model in a unified format called module, defined in [diffstack/modules/module.py](/diffstack/modules/module.py). + +Modules can be chained together to form [stacks](/diffstack/stacks/), which can be trained/evalulated as a whole. For this codebase, we only include CTT, thus the only type of stack is a prediction stack. + +A stack is wrapped as a Pytorch-lightning model for training and evaluation, see [train_pl.py](/diffstack/scripts/train_pl.py) for details. + +The main files of CTT to look for is the [model file](/diffstack/models/CTT.py), and the [module file](/diffstack/modules/predictors/CTT.py). + +We also included a rich collection of [utils functions](/diffstack/utils/), among which many are not used by CTT, but we believe they contribute to creating a convenient code base. ### Data diff --git a/config/templates/AFPredStack.json b/config/templates/AFPredStack.json deleted file mode 100644 index 7a5a731..0000000 --- a/config/templates/AFPredStack.json +++ /dev/null @@ -1,260 +0,0 @@ -{ - "registered_name": "AFPredStack", - "train": { - "logging": { - "terminal_output_to_txt": true, - "log_tb": false, - "log_wandb": true, - "wandb_project_name": "diffstack", - "log_every_n_steps": 10, - "flush_every_n_steps": 100 - }, - "save": { - "enabled": true, - "every_n_steps": 1000, - "best_k": 10, - "save_best_rollout": false, - "save_best_validation": true - }, - "rollout": { - "save_video": true, - "enabled": false, - "every_n_steps": 5000, - "warm_start_n_steps": 1 - }, - "training": { - "batch_size": 100, - "num_steps": 200000, - "num_data_workers": 8 - }, - "validation": { - "enabled": true, - "batch_size": 32, - "num_data_workers": 6, - "every_n_steps": 500, - "num_steps_per_epoch": 50 - }, - "parallel_strategy": "ddp", - "rebuild_cache": false, - "on_ngc": false, - "amp": false, - "auto_batch_size": false, - "max_batch_size": 1000, - "gradient_clip_val": 0.5, - "trajdata_source_train": "train", - "trajdata_source_valid": "val", - "trajdata_source_test": null, - "trajdata_source_root": "nusc_trainval", - "trajdata_val_source_root": null, - "trajdata_test_source_root": null, - "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", - "datamodule_class": "UnifiedDataModule", - "ego_only": true, - "test": { - "enabled": false, - "batch_size": 32, - "num_data_workers": 6, - "every_n_steps": 500, - "num_steps_per_epoch": 50 - } - }, - "env": { - "name": "nusc_trainval", - "rasterizer": { - "raster_size": 224, - "pixel_size": 0.5, - "ego_center": [ - -0.75, - 0.0 - ] - }, - "data_generation_params": { - "other_agents_num": 10, - "max_agents_distance": 30, - "yaw_correction_speed": 1.0 - }, - "simulation": { - "distance_th_close": 30, - "num_simulation_steps": 50, - "start_frame_index": 0, - "vectorize_lane": "ego" - }, - "incl_neighbor_map": false, - "incl_vector_map": true, - "incl_raster_map": false, - "remove_single_successor": false, - "calc_lane_graph": true, - "max_num_lanes": 15, - "num_lane_pts": 32, - "remove_parked": false - }, - "stack": { - "predictor": { - "name": "agentformer", - "seed": 1, - "load_map": false, - "dynamic_type": "Unicycle", - "step_time": 0.1, - "history_num_frames": 10, - "future_num_frames": 20, - "traj_scale": 10, - "nz": 32, - "sample_k": 4, - "tf_model_dim": 256, - "tf_ff_dim": 512, - "tf_nhead": 8, - "tf_dropout": 0.1, - "z_tau": { - "start": 0.5, - "finish": 0.0001, - "decay": 0.5 - }, - "input_type": [ - "scene_norm", - "vel", - "heading" - ], - "fut_input_type": [ - "scene_norm", - "vel", - "heading" - ], - "dec_input_type": [ - "heading" - ], - "pred_type": "dynamic", - "sn_out_type": "norm", - "sn_out_heading": false, - "pos_concat": true, - "rand_rot_scene": false, - "use_map": true, - "pooling": "mean", - "agent_enc_shuffle": false, - "vel_heading": false, - "max_agent_len": 128, - "agent_enc_learn": false, - "use_agent_enc": false, - "motion_dim": 2, - "forecast_dim": 2, - "z_type": "gaussian", - "nlayer": 6, - "ar_detach": true, - "pred_scale": 1.0, - "pos_offset": false, - "learn_prior": true, - "discrete_rot": false, - "map_global_rot": false, - "ar_train": true, - "max_train_agent": 100, - "num_eval_samples": 5, - "UAC": false, - "loss_cfg": { - "kld": { - "min_clip": 1.0 - }, - "sample": { - "weight": 1.0, - "k": 20 - } - }, - "loss_weights": { - "prediction_loss": 1.0, - "kl_loss": 1.0, - "collision_loss": 3.0, - "EC_collision_loss": 5.0, - "diversity_loss": 0.3, - "deviation_loss": 0.1 - }, - "scene_orig_all_past": false, - "conn_dist": 100000.0, - "scene_centric": true, - "stage": 2, - "num_frames_per_stage": 10, - "ego_conditioning": true, - "perturb": { - "enabled": true, - "N_pert": 1, - "OU": { - "theta": 0.8, - "sigma": 2.0, - "scale": [ - 1.0, - 0.3 - ] - } - }, - "map_encoder": { - "model_architecture": "resnet18", - "image_shape": [ - 3, - 224, - 224 - ], - "feature_dim": 32, - "spatial_softmax": { - "enabled": false, - "kwargs": { - "num_kp": 32, - "temperature": 1.0, - "learnable_temperature": false - } - } - }, - "context_encoder": { - "nlayer": 2 - }, - "future_decoder": { - "nlayer": 2, - "out_mlp_dim": [ - 512, - 256 - ] - }, - "future_encoder": { - "nlayer": 2 - }, - "optim_params": { - "policy": { - "learning_rate": { - "initial": 0.0001, - "decay_factor": 0.1, - "epoch_schedule": [] - }, - "regularization": { - "L2": 0.0 - } - } - } - }, - "stack_type": "pred" - }, - "eval": { - "name": null, - "env": "nusc", - "dataset_path": null, - "eval_class": "", - "seed": 0, - "ckpt_root_dir": "checkpoints/", - "ckpt": { - "dir": null, - "ngc_job_id": null, - "ckpt_dir": null, - "ckpt_key": null - }, - "eval": { - "batch_size": 100, - "num_steps": null, - "num_data_workers": 8 - }, - "log_image_frequency": null, - "log_all_image": false, - "trajdata_source_root": "nusc_trainval", - "trajdata_source_eval": "val" - }, - "name": "test", - "root_dir": "predictor_agentformer_trained_models/", - "seed": 1, - "devices": { - "num_gpus": 1 - } -} \ No newline at end of file diff --git a/diffstack/configs/algo_config.py b/diffstack/configs/algo_config.py index e26fc71..cf9007e 100644 --- a/diffstack/configs/algo_config.py +++ b/diffstack/configs/algo_config.py @@ -50,103 +50,6 @@ def __init__(self): self.checkpoint.path = None -class AgentFormerConfig(AlgoConfig): - def __init__(self): - super(AgentFormerConfig, self).__init__() - self.name = "agentformer" - self.seed = 1 - self.load_map = False - self.dynamic_type = "Unicycle" - self.step_time = 0.1 - self.history_num_frames = 10 - self.future_num_frames = 20 - self.traj_scale = 10 - self.nz = 32 - self.sample_k = 4 - self.tf_model_dim = 256 - self.tf_ff_dim = 512 - self.tf_nhead = 8 - self.tf_dropout = 0.1 - self.z_tau.start = 0.5 - self.z_tau.finish = 0.0001 - self.z_tau.decay = 0.5 - self.input_type = ["scene_norm", "vel", "heading"] - self.fut_input_type = ["scene_norm", "vel", "heading"] - self.dec_input_type = ["heading"] - self.pred_type = "dynamic" - self.sn_out_type = "norm" - self.sn_out_heading = False - self.pos_concat = True - self.rand_rot_scene = False - self.use_map = True - self.pooling = "mean" - self.agent_enc_shuffle = False - self.vel_heading = False - self.max_agent_len = 128 - self.agent_enc_learn = False - self.use_agent_enc = False - self.motion_dim = 2 - self.forecast_dim = 2 - self.z_type = "gaussian" - self.nlayer = 6 - self.ar_detach = True - self.pred_scale = 1.0 - self.pos_offset = False - self.learn_prior = True - self.discrete_rot = False - self.map_global_rot = False - self.ar_train = True - self.max_train_agent = 100 - self.num_eval_samples = 5 - - self.UAC = False # compare unconditional and conditional prediction - - self.loss_cfg.kld.min_clip = 1.0 - self.loss_cfg.sample.weight = 1.0 - self.loss_cfg.sample.k = 20 - self.loss_weights.prediction_loss = 1.0 - self.loss_weights.kl_loss = 1.0 - self.loss_weights.collision_loss = 3.0 - self.loss_weights.EC_collision_loss = 5.0 - self.loss_weights.diversity_loss = 0.3 - self.loss_weights.deviation_loss = 0.1 - self.scene_orig_all_past = False - self.conn_dist = 100000.0 - self.scene_centric = True - self.stage = 2 - self.num_frames_per_stage = 10 - - self.ego_conditioning = True - self.perturb.enabled = True - self.perturb.N_pert = 1 - self.perturb.OU.theta = 0.8 - self.perturb.OU.sigma = 2.0 - self.perturb.OU.scale = [1.0, 0.3] - - self.map_encoder.model_architecture = "resnet18" - self.map_encoder.image_shape = [3, 224, 224] - self.map_encoder.feature_dim = 32 - self.map_encoder.spatial_softmax.enabled = False - self.map_encoder.spatial_softmax.kwargs.num_kp = 32 - self.map_encoder.spatial_softmax.kwargs.temperature = 1.0 - self.map_encoder.spatial_softmax.kwargs.learnable_temperature = False - - self.context_encoder.nlayer = 2 - - self.future_decoder.nlayer = 2 - self.future_decoder.out_mlp_dim = [512, 256] - self.future_encoder.nlayer = 2 - - self.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate - self.optim_params.policy.learning_rate.decay_factor = ( - 0.1 # factor to decay LR by (if epoch schedule non-empty) - ) - self.optim_params.policy.learning_rate.epoch_schedule = ( - [] - ) # epochs where LR decay occurs - self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength - - class CTTConfig(AlgoConfig): def __init__(self): super(CTTConfig, self).__init__() diff --git a/diffstack/configs/registry.py b/diffstack/configs/registry.py index e3e9ce8..18672b1 100644 --- a/diffstack/configs/registry.py +++ b/diffstack/configs/registry.py @@ -6,7 +6,6 @@ from diffstack.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig from diffstack.configs.algo_config import ( - AgentFormerConfig, CTTConfig, ) @@ -14,15 +13,6 @@ EXP_CONFIG_REGISTRY = dict() -EXP_CONFIG_REGISTRY["AFPredStack"] = ExperimentConfig( - train_config=TrajdataTrainConfig(), - env_config=TrajdataEnvConfig(), - module_configs=Dict(predictor=AgentFormerConfig()), - registered_name="AFPredStack", - stack_type="pred", -) - - EXP_CONFIG_REGISTRY["CTTPredStack"] = ExperimentConfig( train_config=TrajdataTrainConfig(), env_config=TrajdataEnvConfig(), diff --git a/diffstack/models/agentformer.py b/diffstack/models/agentformer.py deleted file mode 100644 index e46f859..0000000 --- a/diffstack/models/agentformer.py +++ /dev/null @@ -1,3324 +0,0 @@ -import torch -from collections import OrderedDict -from dataclasses import asdict - -from diffstack import dynamics - -torch.manual_seed(0) -torch.cuda.manual_seed_all(0) - -import numpy as np -from torch import nn -from torch.nn import functional as F -from collections import defaultdict -from diffstack.utils.model_utils import ( - AFMLP, - Normal, - Categorical, - initialize_weights, - rotation_2d_torch, - ExpParamAnnealer, -) -from .agentformer_lib import ( - AgentFormerEncoderLayer, - AgentFormerDecoderLayer, - AgentFormerDecoder, - AgentFormerEncoder, -) -from diffstack.models.agentformer_lib import * -import diffstack.utils.tensor_utils as TensorUtils -from diffstack.utils.loss_utils import MultiModal_trajectory_loss -from diffstack.models import base_models -from diffstack.utils.metrics import ( - DynOrnsteinUhlenbeckPerturbation, -) -from diffstack.utils.batch_utils import batch_utils -from diffstack.utils.loss_utils import ( - trajectory_loss, - MultiModal_trajectory_loss, - goal_reaching_loss, - collision_loss, - collision_loss_masked, - log_normal_mixture, - NLL_GMM_loss, - compute_pred_loss, - diversity_score, -) -from diffstack.utils.dist_utils import MAGaussian, MADynGaussian, MAGMM, MADynGMM - - -""" Positional Encoding """ - - -class PositionalAgentEncoding(nn.Module): - def __init__( - self, - d_model, - dropout=0.1, - max_t_len=200, - max_a_len=200, - concat=False, - use_agent_enc=False, - agent_enc_learn=False, - ): - super(PositionalAgentEncoding, self).__init__() - self.dropout = nn.Dropout(p=dropout) - self.concat = concat - self.d_model = d_model - self.use_agent_enc = use_agent_enc - if concat: - self.fc = nn.Linear((3 if use_agent_enc else 2) * d_model, d_model) - - pe = self.build_pos_enc(max_t_len) - self.register_buffer("pe", pe) - if use_agent_enc: - if agent_enc_learn: - self.ae = nn.Parameter(torch.randn(max_a_len, 1, d_model) * 0.1) - else: - ae = self.build_pos_enc(max_a_len) - self.register_buffer("ae", ae) - - def build_pos_enc(self, max_len): - pe = torch.zeros(max_len, self.d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - return pe - - def build_agent_enc(self, max_len): - ae = torch.zeros(max_len, self.d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model) - ) - ae[:, 0::2] = torch.sin(position * div_term) - ae[:, 1::2] = torch.cos(position * div_term) - ae = ae.unsqueeze(0).transpose(0, 1) - return ae - - def get_pos_enc(self, num_t, num_a, t_offset): - pe = self.pe[t_offset : num_t + t_offset, :] - pe = pe.repeat_interleave(num_a, dim=0) - return pe - - def get_agent_enc(self, num_t, num_a, a_offset, agent_enc_shuffle): - if agent_enc_shuffle is None: - ae = self.ae[a_offset : num_a + a_offset, :] - else: - ae = self.ae[agent_enc_shuffle] - ae = ae.repeat(num_t, 1, 1) - return ae - - def forward(self, x, num_a, agent_enc_shuffle=None, t_offset=0, a_offset=0): - num_t = x.shape[0] // num_a - pos_enc = self.get_pos_enc(num_t, num_a, t_offset) - if self.use_agent_enc: - agent_enc = self.get_agent_enc(num_t, num_a, a_offset, agent_enc_shuffle) - if self.concat: - feat = [x, pos_enc.repeat(1, x.size(1), 1)] - if self.use_agent_enc: - feat.append(agent_enc.repeat(1, x.size(1), 1)) - x = torch.cat(feat, dim=-1) - x = self.fc(x) - else: - x += pos_enc - if self.use_agent_enc: - x += agent_enc - return self.dropout(x) - - -""" Context (Past) Encoder """ - - -class ContextEncoder(nn.Module): - def __init__(self, cfg, **kwargs): - super().__init__() - self.cfg = cfg - self.motion_dim = cfg["motion_dim"] - self.model_dim = cfg["tf_model_dim"] - self.ff_dim = cfg["tf_ff_dim"] - self.nhead = cfg["tf_nhead"] - self.dropout = cfg["tf_dropout"] - self.nlayer = cfg["context_encoder"]["nlayer"] - self.input_type = cfg["input_type"] - self.pooling = cfg.pooling - self.agent_enc_shuffle = cfg["agent_enc_shuffle"] - self.vel_heading = cfg["vel_heading"] - in_dim = self.motion_dim * len(self.input_type) - if "map" in self.input_type: - in_dim += cfg.map_encoder.feature_dim - self.motion_dim - self.input_fc = nn.Linear(in_dim, self.model_dim) - - encoder_layers = AgentFormerEncoderLayer( - {}, self.model_dim, self.nhead, self.ff_dim, self.dropout - ) - self.tf_encoder = AgentFormerEncoder(encoder_layers, self.nlayer) - self.pos_encoder = PositionalAgentEncoding( - self.model_dim, - self.dropout, - concat=cfg["pos_concat"], - max_a_len=cfg["max_agent_len"], - use_agent_enc=cfg["use_agent_enc"], - agent_enc_learn=cfg["agent_enc_learn"], - ) - - def forward(self, data): - pre_len, agent_num, bs = ( - data["pre_motion"].size(0), - data["pre_motion"].size(1), - data["pre_motion"].size(2), - ) - PN = pre_len * agent_num - - # get raw features - traj_in = [] - for key in self.input_type: - if key == "pos": - traj_in.append(data["pre_motion"]) # P x N x B x 2 - elif key == "vel": - vel = data["pre_vel"] # P x N x B x 2 - # if len(self.input_type) > 1: - # vel = torch.cat([vel[[0]], vel], dim=0) - if self.vel_heading: - vel = rotation_2d_torch(vel, -data["heading"])[0] - traj_in.append(vel) - elif key == "norm": - traj_in.append(data["pre_motion_norm"]) # P x N x B x 2 - elif key == "scene_norm": - traj_in.append(data["pre_motion_scene_norm"]) # P x N x B x 2 - elif key == "heading": - hv = ( - data["heading_vec"].unsqueeze(0).repeat_interleave(pre_len, dim=0) - ) # P x N x B x 2 - traj_in.append(hv) - elif key == "map": - map_enc = data["map_enc"].unsqueeze(0).repeat((pre_len, 1, 1, 1)) - traj_in.append(map_enc) - else: - raise ValueError("unknown input_type!") - - # extend the agent-pair mask to PN x PN by repeating - # src_agent_mask = data['agent_mask'].clone() # N x N - # src_mask = generate_mask(tf_in.shape[0], tf_in.shape[0], data['agent_num'], src_agent_mask).to(tf_in.device) # PN X PN - - # ******************************** create mask for NaN - - # time-stamp based masking, i.e., not masking for a whole agents - # can only mask part of the agents who have incomplete data - src_mask = ( - data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) - ) # B x PN x 1 - src_mask_square = torch.bmm(src_mask, src_mask.transpose(1, 2)) # B x PN x PN - - # due to the inverse definition in attention.py - # 0 means good, 1 means nan data - enc_mask = (1 - src_mask.transpose(0, 1)).bool() # PN x B x 1 - src_mask_square = (1 - src_mask_square).bool() # B x PN x PN - - # expand mask to head dimensions - src_mask_square = ( - src_mask_square.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * self.nhead, PN, PN) - ) # BH x PN x PN - # repeat_interleave copy for the dimenion that already has sth, e.g., B - # attach the copied dimenion in the end, i.e., BH rather than HB - # the order matters in this case since there are a lot of dimenions - # when printing the matrices, the default is to loop/list from the - # 2nd dimenion, which is H in this case, same for PN (N dim goes first) - - # ******************************** feature encoding - - # mask NaN because even simple fc cannot handle NaN in backward pass - traj_in = torch.cat(traj_in, dim=-1) # P x N x B x feat - traj_in = traj_in.view(PN, bs, traj_in.shape[-1]) # PN x B x feat - traj_in = traj_in.masked_fill_(enc_mask, float(0)) # PN x B x feat - - # input projection - tf_in = self.input_fc(traj_in) # PN x B x feat - tf_in = tf_in.masked_fill_(enc_mask, float(0.0)) - # the resulting features will contain some randome numbers in the - # invalid rows, can suppress using the above comment - # optional: but not masking will not affect the final results - - # ******************************** transformer - - # add positional embedding - agent_enc_shuffle = ( - data["agent_enc_shuffle"] if self.agent_enc_shuffle else None - ) - tf_in_pos = self.pos_encoder( - tf_in, num_a=agent_num, agent_enc_shuffle=agent_enc_shuffle - ) # PN x B x feat - - tf_in_pos = tf_in_pos.masked_fill_(enc_mask, float(0.0)) - # the resulting features will contain some randome numbers in the - # invalid rows, can suppress using the above comment - # optional: but not masking will not affect the final results - - # transformer encoder - assert not torch.isnan(tf_in_pos).any(), "error" - data["context_enc"] = self.tf_encoder( - tf_in_pos, mask=src_mask_square, num_agent=agent_num # BH x PN x PN - ) # PN x B x feat - assert not torch.isnan(data["context_enc"]).any(), "error" - - # mask NaN row (now contained random numbers due to softmax and bias in the linear layers) - # replace random numbers in the NaN rows with 0s to avoid confusion - # here, the masking is needed, otherwise will affect the prior in the pooling - data["context_enc"] = data["context_enc"].masked_fill_( - enc_mask, float(0.0) - ) # PN x B x feat - - # ******************************** compute latent distribution - - # compute per agent context for prior - # using mean will average over a few zeros for the agents with invalid data - context_rs = data["context_enc"].view( - pre_len, agent_num, bs, self.model_dim - ) # P x N x B x feat - if self.pooling == "mean": - data["agent_context"] = torch.mean(context_rs, dim=0) # N x B x feat - else: - data["agent_context"] = torch.max(context_rs, dim=0)[0] - data["agent_context"] = data["agent_context"].view( - agent_num * bs, -1 - ) # NB x feat - - -""" Future Encoder """ - - -class FutureEncoder(nn.Module): - def __init__(self, cfg, **kwargs): - super().__init__() - self.cfg = cfg - self.context_dim = context_dim = cfg["tf_model_dim"] - self.forecast_dim = forecast_dim = cfg["forecast_dim"] - self.nz = cfg["nz"] - self.z_type = cfg["z_type"] - - self.model_dim = cfg["tf_model_dim"] - self.ff_dim = cfg["tf_ff_dim"] - self.nhead = cfg["tf_nhead"] - self.dropout = cfg["tf_dropout"] - self.nlayer = cfg["future_encoder"]["nlayer"] - self.out_mlp_dim = cfg.future_decoder.out_mlp_dim - self.input_type = cfg["fut_input_type"] - self.pooling = cfg.pooling - self.agent_enc_shuffle = cfg.agent_enc_shuffle - self.vel_heading = cfg.vel_heading - # networks - in_dim = forecast_dim * len(self.input_type) - if "map" in self.input_type: - in_dim += cfg.map_encoder.feature_dim - forecast_dim - self.input_fc = nn.Linear(in_dim, self.model_dim) - - decoder_layers = AgentFormerDecoderLayer( - {}, self.model_dim, self.nhead, self.ff_dim, self.dropout - ) - self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) - self.pos_encoder = PositionalAgentEncoding( - self.model_dim, - self.dropout, - concat=cfg["pos_concat"], - max_a_len=cfg["max_agent_len"], - use_agent_enc=cfg["use_agent_enc"], - agent_enc_learn=cfg["agent_enc_learn"], - ) - num_dist_params = ( - 2 * self.nz if self.z_type == "gaussian" else self.nz - ) # either gaussian or discrete - if self.out_mlp_dim is None: - self.q_z_net = nn.Linear(self.model_dim, num_dist_params) - else: - self.out_mlp = AFMLP(self.model_dim, self.out_mlp_dim, "relu") - self.q_z_net = nn.Linear(self.out_mlp.out_dim, num_dist_params) - # initialize - initialize_weights(self.q_z_net.modules()) - - def forward(self, data, reparam=True, temp=0.1): - fut_len, agent_num, bs = ( - data["fut_motion"].size(0), - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - pre_len = data["pre_motion"].size(0) - FN = fut_len * agent_num - PN = pre_len * agent_num - - # get input feature - traj_in = [] - for key in self.input_type: - if key == "pos": - traj_in.append(data["fut_motion"]) # F x N x B x 2 - elif key == "vel": - vel = data["fut_vel"] # F x N x B x 2 - if self.vel_heading: - vel = rotation_2d_torch(vel, -data["heading"])[0] - traj_in.append(vel) - elif key == "norm": - traj_in.append(data["fut_motion_norm"]) # F x N x B x 2 - elif key == "scene_norm": - traj_in.append(data["fut_motion_scene_norm"]) # F x N x B x 2 - elif key == "heading": - hv = ( - data["heading_vec"].unsqueeze(0).repeat_interleave(fut_len, dim=0) - ) # F x N x B x 2 - traj_in.append(hv) - elif key == "map": - map_enc = ( - data["map_enc"] - .unsqueeze(0) - .repeat((data["fut_motion"].shape[0], 1, 1)) - ) - traj_in.append(map_enc) - else: - raise ValueError("unknown input_type!") - - # ******************************** create mask for NaN - - # generate masks, mem_mask for cross attention between past and future, tgt_mask for self_attention between futures - # mem_agent_mask = data['agent_mask'].clone() # N x N - # mem_mask = generate_mask(tf_in.shape[0], data['context_enc'].shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # FN x PN - # tgt_agent_mask = data['agent_mask'].clone() # N x N - # tgt_mask = generate_mask(tf_in.shape[0], tf_in.shape[0], data['agent_num'], tgt_agent_mask).to(tf_in.device) # FN x FN - - # time-stamp based masking, i.e., not masking for a whole agents - # can only mask part of the agents who have incomplete data - fut_mask = ( - data["fut_mask"].transpose(1, 2).contiguous().view(bs, FN, 1) - ) # B x FN x 1 - pre_mask = ( - data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) - ) # B x PN x 1 - mem_mask = torch.bmm(fut_mask, pre_mask.transpose(1, 2)) # B x FN x PN - tgt_mask = torch.bmm(fut_mask, fut_mask.transpose(1, 2)) # B x FN x FN - - # due to the inverse definition in attention.py - # 0 means good, 1 means nan data - enc_mask = (1 - fut_mask.transpose(0, 1)).bool() # FN x B x 1 - mem_mask = (1 - mem_mask).bool() # B x FN x PN - tgt_mask = (1 - tgt_mask).bool() # B x FN x FN - - # expand mask to head dimensions - mem_mask = ( - mem_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * self.nhead, FN, PN) - ) # BH x FN x PN - tgt_mask = ( - tgt_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * self.nhead, FN, FN) - ) # BH x FN x FN - - # ******************************** feature encoding - - # mask NaN because even simple fc cannot handle NaN in backward pass - traj_in = torch.cat(traj_in, dim=-1) # F x N x B x feat - traj_in = traj_in.view(FN, bs, traj_in.shape[-1]) # FN x B x feat - traj_in = traj_in.masked_fill_(enc_mask, float(0)) # FN x B x feat - - # input projection - tf_in = self.input_fc(traj_in) # FN x B x feat - tf_in = tf_in.masked_fill_(enc_mask, float(0.0)) # FN x B x feat - # the resulting features will contain some randome numbers in the - # invalid rows, can suppress using the above comment - # optional: but not masking will not affect the final results - - # ******************************** transformer - - # add positional embedding - agent_enc_shuffle = ( - data["agent_enc_shuffle"] if self.agent_enc_shuffle else None - ) - tf_in_pos = self.pos_encoder( - tf_in, num_a=agent_num, agent_enc_shuffle=agent_enc_shuffle - ) # FN x B x feat - tf_in_pos = tf_in_pos.masked_fill_(enc_mask, float(0.0)) - # the resulting features will contain some randome numbers in the - # invalid rows, can suppress using the above comment - # optional: but not masking will not affect the final results - - # transformer decoder (cross attention between future and context features) - assert not torch.isnan(tf_in_pos).any(), "error" - tf_out, _ = self.tf_decoder( - tf_in_pos, # FN x B x feat - data["context_enc"], # PN x B x feat - memory_mask=mem_mask, # BH x FN x PN - tgt_mask=tgt_mask, # BH x FN x FN - num_agent=agent_num, - ) # FN x B x feat - assert not torch.isnan(tf_out).any(), "error" - - # mask NaN row (now contained random numbers due to softmax and bias in the linear layers) - # replace random numbers in the NaN rows with 0s to avoid confusion - # here, the masking is needed, otherwise will affect the posterior in the pooling - tf_out = tf_out.masked_fill_(enc_mask, float(0.0)) # FN x B x feat - - # ******************************** compute latent distribution - - # compute per agent for posterior - tf_out = tf_out.view(fut_len, agent_num, bs, self.model_dim) # F x N x B x feat - if self.pooling == "mean": - h = torch.mean(tf_out, dim=0) # N x B x feat - else: - h = torch.max(tf_out, dim=0)[0] # N x B x feat - if self.out_mlp_dim is not None: - h = self.out_mlp(h) # N x B x feat - h = h.view(agent_num * bs, -1) # NB x feat - - # ******************************** sample latent code - - # sample latent code from the posterior distribution - # each agent has a separate distribution and sample independently - q_z_params = self.q_z_net(h) # NB x 64 (contain mu and var) - if self.z_type == "gaussian": - data["q_z_dist"] = Normal(params=q_z_params) - else: - data["q_z_dist"] = Categorical(logits=q_z_params, temp=temp) - data["q_z_samp"] = ( - data["q_z_dist"].rsample().reshape(agent_num, bs, -1) - ) # N x B x 32 - - -""" Future Decoder """ - - -class FutureDecoder(nn.Module): - def __init__(self, cfg, **kwargs): - super().__init__() - self.cfg = cfg - self.ar_detach = cfg["ar_detach"] - self.context_dim = context_dim = cfg["tf_model_dim"] - self.forecast_dim = forecast_dim = cfg["forecast_dim"] - self.pred_scale = cfg["pred_scale"] - self.pred_type = cfg["pred_type"] - self.sn_out_type = cfg["sn_out_type"] - self.sn_out_heading = cfg["sn_out_heading"] - self.input_type = cfg["dec_input_type"] - self.future_frames = cfg["future_num_frames"] - self.past_frames = cfg["history_num_frames"] - self.nz = cfg["nz"] - self.z_type = cfg["z_type"] - self.model_dim = cfg["tf_model_dim"] - self.ff_dim = cfg["tf_ff_dim"] - self.nhead = cfg["tf_nhead"] - self.dropout = cfg["tf_dropout"] - self.nlayer = cfg["future_decoder"]["nlayer"] - self.out_mlp_dim = cfg.future_decoder.out_mlp_dim - self.pos_offset = cfg.pos_offset - self.agent_enc_shuffle = cfg["agent_enc_shuffle"] - self.learn_prior = cfg["learn_prior"] - # networks - if self.pred_type in ["dynamic", "dynamic_var"]: - in_dim = 6 + len(self.input_type) * forecast_dim + self.nz - - if cfg.dynamic_type == "Unicycle": - self.dyn = dynamics.Unicycle(cfg.step_time) - else: - raise Exception("not supported dynamic type") - - else: - in_dim = forecast_dim + len(self.input_type) * forecast_dim + self.nz - if "map" in self.input_type: - in_dim += cfg.map_encoder.feature_dim - forecast_dim - self.input_fc = nn.Linear(in_dim, self.model_dim) - - decoder_layers = AgentFormerDecoderLayer( - {}, self.model_dim, self.nhead, self.ff_dim, self.dropout - ) - self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) - - self.pos_encoder = PositionalAgentEncoding( - self.model_dim, - self.dropout, - concat=cfg["pos_concat"], - max_a_len=cfg["max_agent_len"], - use_agent_enc=cfg["use_agent_enc"], - agent_enc_learn=cfg["agent_enc_learn"], - ) - if self.pred_type in ["scene_norm", "vel", "pos", "dynamic"]: - outdim = forecast_dim - elif self.pred_type == "dynamic_var": - outdim = forecast_dim + 2 - if self.out_mlp_dim is None: - self.out_fc = nn.Linear(self.model_dim, outdim) - else: - in_dim = self.model_dim - self.out_mlp = AFMLP(in_dim, self.out_mlp_dim, "relu") - self.out_fc = nn.Linear(self.out_mlp.out_dim, outdim) - initialize_weights(self.out_fc.modules()) - if self.learn_prior: - num_dist_params = ( - 2 * self.nz if self.z_type == "gaussian" else self.nz - ) # either gaussian or discrete - self.p_z_net = nn.Linear(self.model_dim, num_dist_params) - initialize_weights(self.p_z_net.modules()) - - def decode_traj_ar( - self, - data, - mode, - context, - input_dict, - z, - sample_num, - need_weights=False, - cond_idx=None, - ): - # z: N x BS x 32 - - fut_len, agent_num, bs = ( - data["fut_motion"].size(0), - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - pre_len = data["pre_motion"].size(0) - FN = fut_len * agent_num - PN = pre_len * agent_num - device = data["fut_motion"].device - # get input feature, only take the current timestamp as input here - if self.pred_type == "vel": - pre_vel = input_dict["pre_vel"] - fut_vel = input_dict["fut_vel"] - dec_in = torch.cat((pre_vel[[-1]], fut_vel)) # (1+F) x N x BS x 2 - elif self.pred_type == "pos": - pre_motion = input_dict["pre_motion"] - fut_motion = input_dict["fut_motion"] - dec_in = torch.cat((pre_motion[[-1]], fut_motion), 0) # (1+F) x N X BS x 2 - elif self.pred_type == "scene_norm": - pre_motion_scene_norm = input_dict["pre_motion_scene_norm"] - fut_motion_scene_norm = input_dict["fut_motion_scene_norm"] - dec_in = torch.cat( - (pre_motion_scene_norm[[-1]], fut_motion_scene_norm), 0 - ) # (1+F) x N x BS x 2 - elif self.pred_type == "dynamic": - curr_state = input_dict["curr_state"] - pre_state_vec = input_dict["pre_state_vec"] - fut_state_vec = input_dict["fut_state_vec"] - dec_in = torch.cat( - (pre_state_vec[[-1]], fut_state_vec) - ) # (1+F) x N x BS x 6 - dec_state = [curr_state] - elif self.pred_type == "dynamic_var": - curr_state = input_dict["curr_state"] - pre_state_vec = input_dict["pre_state_vec"] - fut_state_vec = input_dict["fut_state_vec"] - dec_in = torch.cat( - (pre_state_vec[[-1]], fut_state_vec) - ) # (1+F) x N x BS x 6 - dec_state = [curr_state] - - else: - dec_in = torch.zeros([1 + fut_len, agent_num, bs * sample_num, 2]).to( - device - ) # (1+F) x N x BS x 2 - - # concatenate conditional input features with latent code - # broadcast to the sample dimension - - z_tiled = z.unsqueeze(0).repeat_interleave(1 + fut_len, 0) - - dec_in = dec_in.view( - (1 + fut_len) * agent_num, bs * sample_num, dec_in.size(-1) - ) # (1+F)N x BS x feat - in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] - - # add additional features such as the map - for key in self.input_type: - if key == "heading": - heading = data["heading_vec"].repeat_interleave( - sample_num, dim=1 - ) # N x BS x 2 - heading_tiled = heading.repeat(1 + fut_len, 1, 1) - - in_arr.append(heading_tiled) - elif key == "map": - map_enc = data["map_enc"].repeat_interleave(sample_num, 1) - map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) - in_arr.append(map_enc_tiled) - else: - raise ValueError("wrong decode input type!") - dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat - device = dec_in.device - orig_dec_in_z_list = list(torch.split(dec_in_z_orig, agent_num)) - updated_dec_in_z_list = list() - dec_in_z = dec_in_z_orig.clone() - - # dec_in_z_padded = torch.cat((dec_in_z,torch.zeros(agent_num*(fut_len-1),bs,D).to(device))) - - # mem_agent_mask = data['agent_mask'].clone() - # tgt_agent_mask = data['agent_mask'].clone() - - if self.pred_type == "dynamic_var": - logvar = list() - - # predict for each timestamps auto-regressively - for fut_index in range(fut_len): - F_tmp = fut_index + 1 - FN_tmp = F_tmp * agent_num - - # ******************************** create mask for NaN - - # agent-wise masking - # mem_mask = pred_utils.generate_mask(tf_in.shape[0], context.shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # (F)N x PN - # tgt_mask = pred_utils.generate_ar_mask(tf_in_pos.shape[0], agent_num, tgt_agent_mask).to(tf_in.device) # (F)N x (F)N - - # time-stamp-based masking - # only using the last timestamp of pre_motion, i.e., the current frame of mask - # repeat it over future frames, i.e., the assumption is that the valid objects - # to predict must have data in the current frame, this is safe since we interpolated - # data in the trajdata, i.e., objects with incomplete trajectories may have NaN in the - # beginning/end of the time window, but not in the current frame - cur_mask = ( - data["pre_mask"][:, :, [-1]] - .transpose(1, 2) - .contiguous() - .view(bs, agent_num, 1) - ) # B x N x 1 - cur_mask = cur_mask.repeat_interleave(sample_num, dim=0) # BS x N x 1 - cur_mask = cur_mask.unsqueeze(1).repeat_interleave(1 + fut_len, dim=1) - - cur_mask[:, F_tmp:] = 0 - if cond_idx is not None: - cur_mask[:, :, cond_idx] = 1 - - cur_mask = cur_mask.view(bs * sample_num, (1 + fut_len) * agent_num, 1) - - pre_mask = ( - data["pre_mask"] - .transpose(1, 2) - .contiguous() - .view(bs, PN, 1) - .repeat_interleave(sample_num, dim=0) - ) # BS x PN x 1 - - mem_mask = torch.bmm(cur_mask, pre_mask.transpose(1, 2)) # BS x (1+F)N x PN - tgt_mask = torch.bmm( - cur_mask, cur_mask.transpose(1, 2) - ) # BS x (1+F)N x (1+F)N - - # due to the inverse definition in attention.py - # 0 means good, 1 means nan data now - cur_mask = (1 - cur_mask.transpose(0, 1)).bool() # (1+F)N x BS x 1 - mem_mask = (1 - mem_mask).bool() # BS x (1+F)N x PN - tgt_mask = (1 - tgt_mask).bool() # BS x (1+F)N x (1+F)N - - # expand mask to head dimensions - mem_mask = ( - mem_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * sample_num * self.nhead, (1 + fut_len) * agent_num, PN) - ) # BSH x (1+F)N x PN - tgt_mask = ( - tgt_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view( - bs * sample_num * self.nhead, - (1 + fut_len) * agent_num, - (1 + fut_len) * agent_num, - ) - ) # BSH x (1+F)N x (1+F)N - - # ******************************** feature encoding - - # mask NaN because even simple fc cannot handle NaN in backward pass - - tf_in = dec_in_z.masked_fill_(cur_mask, float(0)) # (1+F)N x BS x feat - - # input projection - tf_in = self.input_fc( - tf_in - ) # (F)N x BS x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in = tf_in.masked_fill_(cur_mask, float(0.0)) # (1+F)N x BS x feat - - # ******************************** transformer - - # add positional encoding - agent_enc_shuffle = ( - data["agent_enc_shuffle"] if self.agent_enc_shuffle else None - ) - tf_in_pos = self.pos_encoder( - tf_in, - num_a=agent_num, - agent_enc_shuffle=agent_enc_shuffle, - t_offset=self.past_frames - 1 if self.pos_offset else 0, - ) - # (F)N x BS x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in_pos = tf_in_pos.masked_fill_(cur_mask, float(0.0)) - - # transformer decoder (between predicted steps and past context) - assert not torch.isnan(tf_in_pos).any(), "error" - - tf_out, attn_weights = self.tf_decoder( - tf_in_pos, # (F)N x BS x feat - context, # PN x BS x feat - memory_mask=mem_mask, # BSH x (F)N x PN - tgt_mask=tgt_mask, # BSH x (F)N x (F)N - num_agent=agent_num, - need_weights=need_weights, - ) - - assert not torch.isnan(tf_out).any(), "error" - # tf_out: (1+F)N x BS x feat - - # ******************************** output projection - - # convert the output feature to output dimension (x, y) - # out_tmp = tf_out.view(-1, tf_out.shape[-1]) # (F)NS x feat - if self.out_mlp_dim is not None: - out_tmp = self.out_mlp(tf_out) # (F)N x BS x feat - seq_out = self.out_fc(out_tmp) # (F)N x BS x 2 - - # denormalize data and de-rotate - if self.pred_type == "scene_norm" and self.sn_out_type in {"vel", "norm"}: - norm_motion = seq_out.view( - 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] - ) # (1+F) x N x BS x 2 - - # aggregate velocity prediction to obtain location - if self.sn_out_type == "vel": - norm_motion = torch.cumsum(norm_motion, dim=0) # (1+F) x N x BS x 2 - - # default not used - if self.sn_out_heading: - angles = data["heading"].repeat_interleave(sample_num) - norm_motion = rotation_2d_torch(norm_motion, angles)[0] - - # denormalize over the scene - # we are predicting delta with respect to the current frame of data - # will introduce NaN here since the scene_norm data in the current frame has NaN - seq_out = ( - norm_motion + pre_motion_scene_norm[[-1]] - ) # (1+F) x N x BS x 2 - dec_feat_in = seq_out.view( - (1 + fut_len) * agent_num, bs * sample_num, seq_out.shape[-1] - ) # (1+F)N x BS x 2 - elif self.pred_type in ["dynamic", "dynamic_var"]: - traj_scale = data["traj_scale"] - input_seq = TensorUtils.reshape_dimensions_single( - seq_out[..., : self.forecast_dim], 0, 1, [fut_len + 1, -1] - ).permute(1, 2, 0, 3) - - # curr_state_xyhv = torch.cat((curr_state[...,:2],curr_state[...,3:],curr_state[...,2:3]),-1) - state_seq = self.dyn.forward_dynamics(curr_state, input_seq[..., 1:, :]) - # state_seq = torch.cat((state_seq[...,:2],state_seq[...,3:],state_seq[...,2:3]),-1) - state_seq = state_seq.permute(2, 0, 1, 3) - state_seq = torch.cat((curr_state.unsqueeze(0), state_seq), 0) - yaw = state_seq[..., 3:] - vel = state_seq[..., 2:3] / traj_scale - cosyaw = torch.cos(yaw) - sinyaw = torch.sin(yaw) - dec_feat_in = TensorUtils.join_dimensions( - torch.cat( - ( - state_seq[..., :2] / traj_scale, - vel * cosyaw, - vel * sinyaw, - cosyaw, - sinyaw, - ), - -1, - ), - 0, - 2, - ) - - # ******************************** prepare for the next timestamp - - # only take the last few results for the N agents predicted in the last timestamp - if self.ar_detach: - out_in = ( - dec_feat_in[F_tmp * agent_num : (1 + F_tmp) * agent_num] - .clone() - .detach() - ) # N x BS x 2(6) - else: - out_in = dec_feat_in[ - F_tmp * agent_num : (1 + F_tmp) * agent_num - ] # N x BS x 2(6) - - # create input for the next timestamp - in_arr = [out_in, z] # z: N x BS x 32 - - for key in self.input_type: - if key == "heading": - in_arr.append(heading) # z: N x BS x 2 - elif key == "map": - in_arr.append(map_enc) - else: - raise ValueError("wrong decoder input type!") - - # combine with previous information, data in normal forward order - # i.e., newly predicted information attached in the end of features - out_in_z = torch.cat(in_arr, dim=-1) # N x BS x feat - updated_dec_in_z_list.append(out_in_z) - # import pdb - # pdb.set_trace() - curr_dec_list = ( - orig_dec_in_z_list[0:1] - + updated_dec_in_z_list - + orig_dec_in_z_list[F_tmp + 1 :] - ) - dec_in_z = torch.cat(curr_dec_list, 0) - # dec_in_z[F_tmp*agent_num:(1+F_tmp)*agent_num] = out_in_z - - # seq_out: FN x BS x 2 - seq_out = seq_out.view( - 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] - ) # 1+F x N x BS x 2 - seq_out = seq_out[1:] # F x N x BS x 2 - data[f"{mode}_seq_out"] = seq_out - - if self.pred_type == "vel": - dec_motion = torch.cumsum(seq_out, dim=0) # F x N x BS x 2 - dec_motion += pre_motion[[-1]] # F x N X BS x 2 - elif self.pred_type == "pos": - dec_motion = seq_out.clone() - elif self.pred_type == "scene_norm": - dec_motion = seq_out + data["scene_orig"].repeat_interleave( - sample_num, dim=0 - ) # F x N X BS x 2 - elif self.pred_type in ["dynamic", "dynamic_var"]: - input_seq = seq_out.permute(1, 2, 0, 3) - # curr_state_xyhv = torch.cat((curr_state[...,:2],curr_state[...,3:],curr_state[...,2:3]),-1) - state_seq = self.dyn.forward_dynamics( - curr_state, input_seq[..., : self.forecast_dim] - ) - # state_seq = torch.cat((state_seq[...,:2],state_seq[...,3:],state_seq[...,2:3]),-1) - state_seq = state_seq.permute(2, 0, 1, 3) - dec_state = state_seq - dec_motion = state_seq[..., :2] / data["traj_scale"] - data["controls"] = ( - input_seq[..., : self.forecast_dim] - .transpose(0, 2) - .contiguous() - .view(bs, sample_num, agent_num, fut_len, self.forecast_dim) - ) - else: - dec_motion = seq_out + pre_motion[[-1]] # F x N X BS x 2 - - # reshape for loss computation - dec_motion = dec_motion.transpose(0, 2).contiguous() # BS x N x F x 2 - - dec_motion = dec_motion.view( - bs, sample_num, agent_num, fut_len, dec_motion.size(-1) - ) # B x S x N x F x 2 - if self.pred_type in ["dynamic", "dynamic_var"]: - dec_state = ( - dec_state.transpose(0, 2) - .contiguous() - .view(bs, sample_num, agent_num, fut_len, dec_state.size(-1)) - ) - data[f"{mode}_dec_state"] = dec_state - if self.pred_type == "dynamic_var": - logvar = seq_out[..., self.forecast_dim : 2 * self.forecast_dim] - var = torch.exp(logvar) * data["traj_scale"] ** 2 - var = ( - var.permute(2, 1, 0, 3) - .contiguous() - .view(bs, sample_num, agent_num, fut_len, var.size(-1)) - ) - data[f"{mode}_var"] = var - - data[f"{mode}_dec_motion"] = dec_motion - if need_weights: - data["attn_weights"] = attn_weights - - def decode_traj_batch( - self, - data, - mode, - context, - input_dict, - z, - sample_num, - ): - raise NotImplementedError - - def forward( - self, - data, - mode, - sample_num=1, - autoregress=True, - z=None, - need_weights=False, - cond_idx=None, - temp=0.1, - predict=False, - ): - agent_num, bs = ( - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - - # conditional input to the decoding process - context = data["context_enc"].repeat_interleave( - sample_num, dim=1 - ) # PN x BS x feat - - pre_motion = data["pre_motion"].repeat_interleave( - sample_num, dim=2 - ) # P x N X BS x 2 - fut_motion = data["fut_motion"].repeat_interleave( - sample_num, dim=2 - ) # F x N X BS x 2 - pre_motion_scene_norm = data["pre_motion_scene_norm"].repeat_interleave( - sample_num, dim=2 - ) # P x N x BS x 2 - fut_motion_scene_norm = data["fut_motion_scene_norm"].repeat_interleave( - sample_num, dim=2 - ) # F x N x BS x 2 - input_dict = dict( - pre_motion=pre_motion, - fut_motion=fut_motion, - pre_motion_scene_norm=pre_motion_scene_norm, - fut_motion_scene_norm=fut_motion_scene_norm, - ) - if self.pred_type == "vel": - input_dict["pre_vel"] = data["pre_vel"].repeat_interleave( - sample_num, dim=2 - ) # P x N x BS x 2 - input_dict["fut_vel"] = data["fut_vel"].repeat_interleave( - sample_num, dim=2 - ) # F x N x BS x 2 - elif self.pred_type in ["dynamic", "dynamic_var"]: - traj_scale = data["traj_scale"] - pre_state = torch.cat( - ( - data["pre_motion"] * traj_scale, - torch.norm(data["pre_vel"], dim=-1, keepdim=True) * traj_scale, - data["pre_heading_raw"].transpose(0, 2).unsqueeze(-1), - ), - -1, - ) # P x N x B x 4 (unscaled) - - pre_state_vec = torch.cat( - (data["pre_motion"], data["pre_vel"], data["pre_heading_vec"]), -1 - ) # P x N x B x 6 (scaled) - fut_state_vec = torch.cat( - (data["fut_motion"], data["fut_vel"], data["fut_heading_vec"]), -1 - ) # F x N x B x 6 (scaled) - input_dict["curr_state"] = pre_state[-1].repeat_interleave( - sample_num, dim=1 - ) - input_dict["pre_state_vec"] = pre_state_vec.repeat_interleave( - sample_num, dim=2 - ) - input_dict["fut_state_vec"] = fut_state_vec.repeat_interleave( - sample_num, dim=2 - ) - - # p(z), compute prior distribution - if mode == "infer": - prior_key = "p_z_dist_infer" - else: - prior_key = "q_z_dist" if "q_z_dist" in data else "p_z_dist" - - if self.learn_prior: - p_z_params0 = self.p_z_net(data["agent_context"]) - - h = data["agent_context"].repeat_interleave(sample_num, dim=0) # NBS x feat - p_z_params = self.p_z_net(h) # NBS x 64 - if self.z_type == "gaussian": - data["p_z_dist_infer"] = Normal(params=p_z_params) - data["p_z_dist"] = Normal(params=p_z_params0) - else: - data["p_z_dist_infer"] = Categorical(logits=p_z_params, temp=temp) - data["p_z_dist"] = Categorical(logits=p_z_params0, temp=temp) - else: - if self.z_type == "gaussian": - data[prior_key] = Normal( - mu=torch.zeros(pre_motion.shape[1], self.nz).to(pre_motion.device), - logvar=torch.zeros(pre_motion.shape[1], self.nz).to( - pre_motion.device - ), - ) - else: - data[prior_key] = Categorical( - logits=torch.zeros(pre_motion.shape[1], self.nz).to( - pre_motion.device - ) - ) - - # sample latent code from the distribution - if z is None: - # use latent code z from posterior for training - if mode == "train": - z = data["q_z_samp"] # N x B x 32 - - # use latent code z from posterior for evaluating the reconstruction loss - elif mode == "recon": - z = data["q_z_dist"].mode() # NB x 32 - z = z.view(agent_num, bs, z.size(-1)) # N x B x 32 - - # use latent code z from the prior for inference - elif mode == "infer": - # dist = data["p_z_dist_infer"] if "p_z_dist_infer" in data else data["q_z_dist_infer"] - # z = dist.sample() # NBS x 32 - # import pdb - # pdb.set_trace() - - dist = ( - data["q_z_dist"] - if data["q_z_dist"] is not None - else data["p_z_dist"] - ) - if self.z_type == "gaussian": - if predict: - z = dist.pseudo_sample(sample_num) - else: - z = data["p_z_dist_infer"].sample() - D = z.shape[-1] - samples = z.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3) - mu = dist.mu.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3)[ - :, :, 0 - ] - sigma = dist.sigma.reshape(agent_num, bs, -1, D).permute( - 1, 0, 2, 3 - )[:, :, 0] - data["prob"] = self.pseudo_sample_prob( - samples, mu, sigma, data["agent_avail"] - ) - elif self.z_type == "discrete": - if predict: - z = dist.pseudo_sample(sample_num).contiguous() - else: - z = dist.rsample(sample_num).contiguous() - D = z.shape[-1] - idx = z.argmax(dim=-1) - prob_sample = torch.gather(dist.probs, -1, idx) - prob_sample = prob_sample.reshape(agent_num, bs, -1).mean(0) - prob_sample = prob_sample / prob_sample.sum(-1, keepdim=True) - data["prob"] = prob_sample - - z = z.view(agent_num, bs * sample_num, z.size(-1)) # N x BS x 32 - else: - raise ValueError("Unknown Mode!") - - # trajectory decoding - if autoregress: - self.decode_traj_ar( - data, - mode, - context, - input_dict, - z, - sample_num, - need_weights=need_weights, - cond_idx=cond_idx, - ) - # self.decode_traj_ar_orig( - # data, - # mode, - # context, - # pre_motion, - # pre_vel, - # pre_motion_scene_norm, - # z, - # sample_num, - # need_weights=need_weights, - # ) - else: - self.decode_traj_batch( - data, - mode, - context, - input_dict, - z, - sample_num, - ) - - def pseudo_sample_prob(self, sample, mu, sigma, mask): - """ - A simple K-means estimation to estimate the probability of samples - """ - bs, Na, Ns, D = sample.shape - device = sample.device - Np = Ns * 50 - particle = torch.randn([bs, Na, Np, D]).to(device) * sigma.unsqueeze( - -2 - ) + mu.unsqueeze(-2) - dis = torch.linalg.norm(sample.unsqueeze(-2) - particle.unsqueeze(-3), dim=-1) - dis = (dis * mask[..., None, None]).sum(1) - idx = torch.argmin(dis, -2) - flag = idx.unsqueeze(1) == torch.arange(Ns).view(1, Ns, 1).repeat_interleave( - bs, 0 - ).to(device) - prob = flag.sum(-1) / Np - return prob - - -class FutureARDecoder(nn.Module): - def __init__(self, cfg, **kwargs): - super().__init__() - self.cfg = cfg - self.ar_detach = cfg["ar_detach"] - self.context_dim = context_dim = cfg["tf_model_dim"] - self.forecast_dim = forecast_dim = cfg["forecast_dim"] - self.pred_scale = cfg["pred_scale"] - self.pred_type = cfg["pred_type"] - self.sn_out_type = cfg["sn_out_type"] - self.sn_out_heading = cfg["sn_out_heading"] - self.input_type = cfg["dec_input_type"] - self.future_frames = cfg["future_num_frames"] - self.past_frames = cfg["history_num_frames"] - self.z_type = cfg["z_type"] - self.nz = cfg["nz"] if self.z_type != "None" else 0 - self.model_dim = cfg["tf_model_dim"] - self.ff_dim = cfg["tf_ff_dim"] - self.nhead = cfg["tf_nhead"] - self.dropout = cfg["tf_dropout"] - self.nlayer = cfg["future_decoder"]["nlayer"] - self.out_mlp_dim = cfg.future_decoder.out_mlp_dim - self.pos_offset = cfg.pos_offset - self.agent_enc_shuffle = cfg["agent_enc_shuffle"] - self.learn_prior = cfg["learn_prior"] - # networks - assert self.pred_type == "dynamic_AR" - in_dim = 6 + len(self.input_type) * forecast_dim + self.nz - - if cfg.dynamic_type == "Unicycle": - self.dyn = dynamics.Unicycle(cfg.step_time) - else: - raise Exception("not supported dynamic type") - - if "map" in self.input_type: - in_dim += cfg.map_encoder.feature_dim - forecast_dim - self.input_fc = nn.Linear(in_dim, self.model_dim) - - decoder_layers = AgentFormerDecoderLayer( - {}, self.model_dim, self.nhead, self.ff_dim, self.dropout - ) - self.tf_decoder = AgentFormerDecoder(decoder_layers, self.nlayer) - - self.pos_encoder = PositionalAgentEncoding( - self.model_dim, - self.dropout, - concat=cfg["pos_concat"], - max_a_len=cfg["max_agent_len"], - use_agent_enc=cfg["use_agent_enc"], - agent_enc_learn=cfg["agent_enc_learn"], - ) - if cfg.dist_type == "gaussian": - outdim = self.dyn.udim * 2 + cfg.scene_var_dim * self.dyn.udim - if cfg.output_varx and cfg.dist_obj == "state": - outdim += self.dyn.xdim - elif cfg.dist_type == "GMM": - self.GMM_M = cfg.GMM_M - outdim = (forecast_dim * 2 + cfg.scene_var_dim * forecast_dim) * self.GMM_M - if cfg.output_varx and cfg.dist_obj == "state": - outdim += self.dyn.xdim * self.GMM_M - self.GMM_pi_net = nn.Linear(self.model_dim, self.GMM_M) - - if self.out_mlp_dim is None: - self.out_fc = nn.Linear(self.model_dim, outdim) - else: - in_dim = self.model_dim - self.out_mlp = AFMLP(in_dim, self.out_mlp_dim, "relu") - self.out_fc = nn.Linear(self.out_mlp.out_dim, outdim) - initialize_weights(self.out_fc.modules()) - if self.learn_prior and self.z_type != "None": - num_dist_params = ( - 2 * self.nz if self.z_type == "gaussian" else self.nz - ) # either gaussian or discrete - self.p_z_net = nn.Linear(self.model_dim, num_dist_params) - initialize_weights(self.p_z_net.modules()) - - def decode_traj_ar( - self, - data, - mode, - context, - input_dict, - z, - sample_num, - gt_step, - need_weights=False, - ): - # z: N x BS x 32 - fut_len, agent_num, bs = ( - data["fut_motion"].size(0), - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - # assert mode=="infer" - gt_step = gt_step if gt_step < fut_len else fut_len - 1 - pre_len = data["pre_motion"].size(0) - FN = fut_len * agent_num - PN = pre_len * agent_num - # get input feature, only take the current timestamp as input here - - curr_state = input_dict["curr_state"] - pre_state_vec = input_dict["pre_state_vec"] - fut_state_vec = input_dict["fut_state_vec"] - dec_in = torch.cat((pre_state_vec[[-1]], fut_state_vec)) # (1+F) x N x BS x 6 - dec_state = [curr_state] - - # concatenate conditional input features with latent code - # broadcast to the sample dimension - - dec_in = dec_in.view( - (1 + fut_len) * agent_num, bs * sample_num, dec_in.size(-1) - ) # (1+F)N x BS x feat - if z is not None: - z_tiled = z.unsqueeze(0).repeat_interleave(1 + fut_len, 0) - in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] - else: - in_arr = [dec_in] - - # add additional features such as the map - for key in self.input_type: - if key == "heading": - heading = data["heading_vec"].repeat_interleave( - sample_num, dim=1 - ) # N x BS x 2 - heading_tiled = heading.repeat(1 + fut_len, 1, 1) - - in_arr.append(heading_tiled) - elif key == "map": - map_enc = data["map_enc"].repeat_interleave(sample_num, 1) - map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) - in_arr.append(map_enc_tiled) - else: - raise ValueError("wrong decode input type!") - dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat - orig_dec_in_z_list = list(torch.split(dec_in_z_orig, agent_num)) - updated_dec_in_z_list = list() - dec_in_z = dec_in_z_orig.clone() - - # mem_agent_mask = data['agent_mask'].clone() - # tgt_agent_mask = data['agent_mask'].clone() - - # predict for each timestamps auto-regressively - - input_pred = [ - torch.zeros( - [agent_num, bs * sample_num, self.dyn.udim], device=curr_state.device - ) - for _ in range(fut_len) - ] - state_pred = [torch.zeros_like(curr_state) for _ in range(fut_len + 1)] - state_pred[0] = curr_state - for i in range(1, 1 + gt_step): - state_pred[i] = input_dict["fut_state"][i - 1] - - for fut_index in range(gt_step, fut_len): - F_tmp = fut_index + 1 - - # ******************************** create mask for NaN - - # agent-wise masking - # mem_mask = pred_utils.generate_mask(tf_in.shape[0], context.shape[0], data['agent_num'], mem_agent_mask).to(tf_in.device) # (F)N x PN - # tgt_mask = pred_utils.generate_ar_mask(tf_in_pos.shape[0], agent_num, tgt_agent_mask).to(tf_in.device) # (F)N x (F)N - - # time-stamp-based masking - # only using the last timestamp of pre_motion, i.e., the current frame of mask - # repeat it over future frames, i.e., the assumption is that the valid objects - # to predict must have data in the current frame, this is safe since we interpolated - # data in the trajdata, i.e., objects with incomplete trajectories may have NaN in the - # beginning/end of the time window, but not in the current frame - cur_mask = ( - data["pre_mask"][:, :, [-1]] - .transpose(1, 2) - .contiguous() - .view(bs, agent_num, 1) - ) # B x N x 1 - cur_mask = cur_mask.repeat_interleave(sample_num, dim=0) # BS x N x 1 - cur_mask = cur_mask.unsqueeze(1).repeat_interleave(1 + fut_len, dim=1) - - cur_mask[:, F_tmp:] = 0 - - cur_mask = cur_mask.view(bs * sample_num, (1 + fut_len) * agent_num, 1) - - pre_mask = ( - data["pre_mask"] - .transpose(1, 2) - .contiguous() - .view(bs, PN, 1) - .repeat_interleave(sample_num, dim=0) - ) # BS x PN x 1 - - mem_mask = torch.bmm(cur_mask, pre_mask.transpose(1, 2)) # BS x (1+F)N x PN - tgt_mask = torch.bmm( - cur_mask, cur_mask.transpose(1, 2) - ) # BS x (1+F)N x (1+F)N - - # due to the inverse definition in attention.py - # 0 means good, 1 means nan data now - cur_mask = (1 - cur_mask.transpose(0, 1)).bool() # (1+F)N x BS x 1 - mem_mask = (1 - mem_mask).bool() # BS x (1+F)N x PN - tgt_mask = (1 - tgt_mask).bool() # BS x (1+F)N x (1+F)N - - # expand mask to head dimensions - mem_mask = ( - mem_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * sample_num * self.nhead, (1 + fut_len) * agent_num, PN) - ) # BSH x (1+F)N x PN - tgt_mask = ( - tgt_mask.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view( - bs * sample_num * self.nhead, - (1 + fut_len) * agent_num, - (1 + fut_len) * agent_num, - ) - ) # BSH x (1+F)N x (1+F)N - - # ******************************** feature encoding - - # mask NaN because even simple fc cannot handle NaN in backward pass - - tf_in = dec_in_z.masked_fill_(cur_mask, float(0)) # (1+F)N x BS x feat - - # input projection. - tf_in = self.input_fc( - tf_in - ) # (F)N x BS x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in = tf_in.masked_fill_(cur_mask, float(0.0)) # (1+F)N x BS x feat - - # ******************************** transformer - - # add positional encoding - agent_enc_shuffle = ( - data["agent_enc_shuffle"] if self.agent_enc_shuffle else None - ) - tf_in_pos = self.pos_encoder( - tf_in, - num_a=agent_num, - agent_enc_shuffle=agent_enc_shuffle, - t_offset=self.past_frames - 1 if self.pos_offset else 0, - ) - # (F)N x BS x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in_pos = tf_in_pos.masked_fill_(cur_mask, float(0.0)) - - # transformer decoder (between predicted steps and past context) - assert not torch.isnan(tf_in_pos).any(), "error" - - tf_out, attn_weights = self.tf_decoder( - tf_in_pos, # (F)N x BS x feat - context, # PN x BS x feat - memory_mask=mem_mask, # BSH x (F)N x PN - tgt_mask=tgt_mask, # BSH x (F)N x (F)N - num_agent=agent_num, - need_weights=need_weights, - ) - - assert not torch.isnan(tf_out).any(), "error" - # tf_out: (1+F)N x BS x feat - - # ******************************** output projection - - # convert the output feature to output dimension (x, y) - # out_tmp = tf_out.view(-1, tf_out.shape[-1]) # (F)NS x feat - x_t = state_pred[fut_index] - if self.out_mlp_dim is not None: - out_tmp = self.out_mlp(tf_out) # (F)N x BS x feat - seq_out = self.out_fc(out_tmp) # (F)N x BS x 2 - seq_out_T = TensorUtils.reshape_dimensions_single( - seq_out, 0, 1, (fut_len + 1, -1) - ) - xdim, udim = self.dyn.xdim, self.dyn.udim - if self.cfg.dist_obj == "state": - if self.cfg.dist_type == "gaussian": - seq_out_t = seq_out_T[fut_index].transpose(0, 1) - mu_u = seq_out_t[..., :udim].reshape( - bs * sample_num, agent_num, udim - ) - logvar_u = seq_out_t[..., udim : 2 * udim].reshape( - bs * sample_num, agent_num, udim - ) - if self.cfg.output_varx: - logvar_x = seq_out_t[..., 2 * udim : 2 * udim + xdim].reshape( - bs * sample_num, agent_num, xdim - ) - var_x = torch.exp(logvar_x) - K = seq_out_t[..., 2 * udim + xdim :].reshape( - bs * sample_num, agent_num, udim, self.cfg.scene_var_dim - ) - else: - K = seq_out_t[..., 2 * udim :].reshape( - bs * sample_num, agent_num, udim, self.cfg.scene_var_dim - ) - var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ - None, None - ] * torch.ones( - [bs * sample_num, agent_num, xdim], device=mu_u.device - ) - dist = MADynGaussian(mu_u, torch.exp(logvar_u), var_x, K, self.dyn) - elif self.cfg.dist_type == "GMM": - seq_out_t = ( - seq_out_T[fut_index] - .transpose(0, 1) - .reshape(bs * sample_num, agent_num, self.GMM_M, -1) - .transpose(1, 2) - ) - mu_u = seq_out_t[..., : self.forecast_dim] - logvar_u = seq_out_t[..., self.forecast_dim : 2 * self.forecast_dim] - if self.cfg.output_varx: - logvar_x = seq_out_t[..., 2 * udim : 2 * udim + xdim].reshape( - bs * sample_num, self.GMM_M, agent_num, xdim - ) - var_x = torch.exp(logvar_x) - K = seq_out_t[..., 2 * udim + xdim :].reshape( - bs * sample_num, - self.GMM_M, - agent_num, - udim, - self.cfg.scene_var_dim, - ) - else: - K = seq_out_t[..., 2 * udim :].reshape( - *seq_out_t.shape[:-1], udim, self.cfg.scene_var_dim - ) - var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ - None, None, None - ] * torch.ones( - [bs * sample_num, self.GMM_M, agent_num, self.dyn.xdim], - device=mu_u.device, - ) - tf_feature_pooled = tf_out.reshape( - [fut_len + 1, agent_num, bs * sample_num, -1] - )[fut_index].max(0)[0] - logpi = self.GMM_pi_net(tf_feature_pooled) - pi = torch.softmax(logpi, dim=-1) - dist = MADynGMM(mu_u, torch.exp(logvar_u), var_x, K, pi, self.dyn) - - # mu_u = seq_out_T[...,:self.forecast_dim] - # logvar_u = seq_out_T[...,self.forecast_dim:2*self.forecast_dim] - # scene_var_M = seq_out_T[...,2*self.forecast_dim:].reshape(*seq_out_T.shape[:-1],self.forecast_dim,self.cfg.scene_var_dim) - # var_u = torch.exp(logvar_u) - # scene_noise = torch.randn(*scene_var_M.shape[1:-2],self.cfg.scene_var_dim).to(mu_u.device) - - # u_t_sample = mu_u[fut_index]+torch.randn_like(mu_u[fut_index])*torch.sqrt(var_u[fut_index]) + (scene_var_M[fut_index]@scene_noise.unsqueeze(-1)).squeeze(-1) - - # u_t_sample = u_t_sample.squeeze(-2) - # input_pred[fut_index] = u_t_sample - - # xp = self.dyn.step(x_t,u_t_sample) - xp = ( - dist.rsample( - x_t.transpose(0, 1).reshape(bs * sample_num, agent_num, -1), 1 - ) - .squeeze(1) - .transpose(0, 1) - ) - - elif self.cfg.dist_obj == "input": - if self.cfg.dist_type == "gaussian": - seq_out_t = seq_out_T[fut_index].transpose(0, 1) - mu_u = seq_out_t[..., :udim].reshape( - bs * sample_num, agent_num, udim - ) - logvar_u = seq_out_t[..., udim : 2 * udim].reshape( - bs * sample_num, agent_num, udim - ) - K = seq_out_t[..., 2 * udim :].reshape( - bs * sample_num, agent_num, udim, self.cfg.scene_var_dim - ) - dist = MAGaussian(mu_u, torch.exp(logvar_u), K) - - elif self.cfg.dist_type == "GMM": - seq_out_t = ( - seq_out_T[fut_index] - .transpose(0, 1) - .reshape(bs * sample_num, agent_num, self.GMM_M, -1) - .transpose(1, 2) - ) - mu_u = seq_out_t[..., : self.forecast_dim] - logvar_u = seq_out_t[..., self.forecast_dim : 2 * self.forecast_dim] - K = seq_out_t[..., 2 * self.forecast_dim :].reshape( - *seq_out_t.shape[:-1], self.forecast_dim, self.cfg.scene_var_dim - ) - tf_feature_pooled = tf_out.reshape( - [fut_len + 1, agent_num, bs * sample_num, -1] - )[fut_index].max(0)[0] - logpi = self.GMM_pi_net(tf_feature_pooled) - pi = torch.softmax(logpi, dim=-1) - dist = MAGMM(mu_u, torch.exp(logvar_u), K, pi) - - up = dist.rsample(1).squeeze(1).transpose(0, 1) - xp = self.dyn.step(x_t, up) - # denormalize data and de-rotate - state_pred[fut_index + 1] = xp - traj_scale = data["traj_scale"] - state_seq = torch.stack(state_pred, 0).clone() - yaw = state_seq[..., 3:] - vel = state_seq[..., 2:3] / traj_scale - cosyaw = torch.cos(yaw) - sinyaw = torch.sin(yaw) - dec_feat_in = TensorUtils.join_dimensions( - torch.cat( - ( - state_seq[..., :2] / traj_scale, - vel * cosyaw, - vel * sinyaw, - cosyaw, - sinyaw, - ), - -1, - ), - 0, - 2, - ) - - # ******************************** prepare for the next timestamp - - # only take the last few results for the N agents predicted in the last timestamp - if self.ar_detach: - out_in = dec_feat_in[ - F_tmp * agent_num : (1 + F_tmp) * agent_num - ].detach() # N x BS x 2(6) - else: - out_in = dec_feat_in[ - F_tmp * agent_num : (1 + F_tmp) * agent_num - ] # N x BS x 2(6) - - # create input for the next timestamp - if z is not None: - in_arr = [out_in, z] # z: N x BS x 32 - else: - in_arr = [out_in] - - for key in self.input_type: - if key == "heading": - in_arr.append(heading) # z: N x BS x 2 - elif key == "map": - in_arr.append(map_enc) - else: - raise ValueError("wrong decoder input type!") - - # combine with previous information, data in normal forward order - # i.e., newly predicted information attached in the end of features - out_in_z = torch.cat(in_arr, dim=-1) # N x BS x feat - updated_dec_in_z_list.append(out_in_z) - - curr_dec_list = ( - orig_dec_in_z_list[: gt_step + 1] - + updated_dec_in_z_list - + orig_dec_in_z_list[F_tmp + 1 :] - ) - dec_in_z = torch.cat(curr_dec_list, 0) - del updated_dec_in_z_list, orig_dec_in_z_list - # seq_out: FN x BS x 2 - seq_out = seq_out.view( - 1 + fut_len, agent_num, bs * sample_num, seq_out.shape[-1] - ) # 1+F x N x BS x 2 - seq_out = seq_out[1:] # F x N x BS x 2 - data[f"{mode}_seq_out"] = seq_out - - state_pred = torch.stack(state_pred, 0) - input_pred = torch.stack(input_pred, 0) - - dec_state = state_pred[1:] - dec_motion = state_pred[1:, ..., :2] / data["traj_scale"] - data["controls"] = input_pred[..., : self.dyn.udim].permute(1, 2, 0, 3) - - # reshape for loss computation - dec_motion = dec_motion.transpose(0, 2).contiguous() # BS x N x F x 2 - dec_motion = dec_motion.view( - bs, sample_num, agent_num, fut_len, dec_motion.size(-1) - ) # B x S x N x F x 2 - dec_state = ( - dec_state.transpose(0, 2) - .contiguous() - .view(bs, sample_num, agent_num, fut_len, dec_state.size(-1)) - ) - data[f"{mode}_dec_state"] = dec_state - - data[f"{mode}_dec_motion"] = dec_motion - if need_weights: - data["attn_weights"] = attn_weights - - def calc_traj_likelihood( - self, - data, - context, - input_dict, - z, - ): - # z: N x BS x 32 - fut_len, agent_num, bs = ( - data["fut_motion"].size(0), - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - pre_len = data["pre_motion"].size(0) - FN = fut_len * agent_num - PN = pre_len * agent_num - # get input feature, only take the current timestamp as input here - - curr_state = input_dict["curr_state"][:, :bs] - pre_state_vec = input_dict["pre_state_vec"][:, :, :bs] - fut_state_vec = input_dict["fut_state_vec"][:, :, :bs] - dec_in = torch.cat((pre_state_vec[[-1]], fut_state_vec)) # (1+F) x N x B x 6 - dec_state = [curr_state] - - # concatenate conditional input features with latent code - # broadcast to the sample dimension - - dec_in = dec_in.view( - (1 + fut_len) * agent_num, bs, dec_in.size(-1) - ) # (1+F)N x B x feat - if z is not None: - z_tiled = z[:, :bs].unsqueeze(0).repeat_interleave(1 + fut_len, 0) - in_arr = [dec_in, TensorUtils.join_dimensions(z_tiled, 0, 2)] - else: - in_arr = [dec_in] - - # add additional features such as the map - for key in self.input_type: - if key == "heading": - heading = data["heading_vec"] # N x B x 2 - heading_tiled = heading.repeat(1 + fut_len, 1, 1) - - in_arr.append(heading_tiled) - elif key == "map": - map_enc = data["map_enc"] - map_enc_tiled = map_enc.repeat(1 + fut_len, 1, 1) - in_arr.append(map_enc_tiled) - else: - raise ValueError("wrong decode input type!") - dec_in_z_orig = torch.cat(in_arr, dim=-1) # (1)N x BS x feat - dec_in_z = dec_in_z_orig.clone() - - # predict for each timestamps auto-regressively - - state_seq = torch.cat( - [curr_state.unsqueeze(0), input_dict["fut_state"][:, :, :bs]], 0 - ) - - cur_mask = ( - data["pre_mask"][:, :, [-1]] - .transpose(1, 2) - .contiguous() - .view(bs, agent_num, 1) - ).detach() # B x N x 1 - cur_mask = cur_mask.unsqueeze(1).repeat_interleave( - 1 + fut_len, dim=1 - ) # B x (1+F) x Na x 1 - cur_mask_tiled = cur_mask.unsqueeze(1).repeat_interleave( - fut_len, dim=1 - ) # B x F x (1+F) x Na x 1 - fut_mask = torch.tril(torch.ones([fut_len, fut_len + 1]), 0).to(cur_mask.device) - cur_mask_tiled *= fut_mask[None, :, :, None, None] - cur_mask_tiled = cur_mask_tiled.view(-1, (1 + fut_len) * agent_num, 1) - pre_mask = ( - data["pre_mask"].transpose(1, 2).contiguous().view(bs, PN, 1) - ).detach() # BS x PN x 1 - pre_mask_tiled = pre_mask.repeat_interleave(fut_len, dim=0) # BSF x PN x 1 - mem_mask_tiled = torch.bmm( - cur_mask_tiled, pre_mask_tiled.transpose(1, 2) - ) # BSF x (1+F)N x PN - tgt_mask_tiled = torch.bmm( - cur_mask_tiled, cur_mask_tiled.transpose(1, 2) - ) # BSF x (1+F)N x (1+F)N - # due to the inverse definition in attention.py - # 0 means good, 1 means nan data now - cur_mask_tiled = (1 - cur_mask_tiled.transpose(0, 1)).bool() # (1+F)N x BSF x 1 - mem_mask_tiled = (1 - mem_mask_tiled).bool() # BSF x (1+F)N x PN - tgt_mask_tiled = (1 - tgt_mask_tiled).bool() # BSF x (1+F)N x (1+F)N - - # expand mask to head dimensions - mem_mask_tiled = ( - mem_mask_tiled.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view(bs * fut_len * self.nhead, (1 + fut_len) * agent_num, PN) - ) # BSFH x (1+F)N x PN - tgt_mask_tiled = ( - tgt_mask_tiled.unsqueeze(1) - .repeat_interleave(self.nhead, dim=1) - .view( - bs * fut_len * self.nhead, - (1 + fut_len) * agent_num, - (1 + fut_len) * agent_num, - ) - ) # BSFH x (1+F)N x (1+F)N - tf_in_tiled = dec_in_z.repeat_interleave(fut_len, 1) * torch.logical_not( - cur_mask_tiled - ) - # input projection - tf_in_tiled = self.input_fc( - tf_in_tiled - ) # (F)N x BSF x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in_tiled = tf_in_tiled.masked_fill_( - cur_mask_tiled, float(0.0) - ) # (1+F)N x BS x feat - - # ******************************** transformer - - # add positional encoding - agent_enc_shuffle = ( - data["agent_enc_shuffle"] if self.agent_enc_shuffle else None - ) - tf_in_pos_tiled = self.pos_encoder( - tf_in_tiled, - num_a=agent_num, - agent_enc_shuffle=agent_enc_shuffle, - t_offset=self.past_frames - 1 if self.pos_offset else 0, - ) - # (F)N x BSF x feat, F is increamentally increased - - # optional: not masking will not affect the final results - # just to suppress some random numbers generated by linear layer's bias - # for cleaner printing, but these random numbers are not used later due to masking - tf_in_pos_tiled = tf_in_pos_tiled.masked_fill_(cur_mask_tiled, float(0.0)) - - # transformer decoder (between predicted steps and past context) - assert not torch.isnan(tf_in_pos_tiled).any(), "error" - context_tiled = context[:, :bs].repeat_interleave(fut_len, 1) - tf_out_tiled, _ = self.tf_decoder( - tf_in_pos_tiled, # (F)N x BSF x feat - context_tiled, # PN x BSF x feat - memory_mask=mem_mask_tiled, # BH x (F)N x PN - tgt_mask=tgt_mask_tiled, # BH x (F)N x (F)N - num_agent=agent_num, - need_weights=False, - ) - - assert not torch.isnan(tf_out_tiled).any(), "error" - # tf_out: (1+F)N x BSF x feat - - # ******************************** output projection - - # convert the output feature to output dimension (x, y) - if self.out_mlp_dim is not None: - out_tmp_tiled = self.out_mlp(tf_out_tiled) # (F)N x BSF x feat - seq_out_tiled = self.out_fc(out_tmp_tiled) - # select the diagonal of the output - seq_out_tiled = seq_out_tiled.reshape([fut_len + 1, agent_num, bs, fut_len, -1]) - seq_out_diag = torch.diagonal(seq_out_tiled, dim1=0, dim2=3).permute( - 1, 3, 0, 2 - ) # bs x fut_len x agent_num x dim - seq_out_diag = seq_out_diag.reshape(bs * fut_len, agent_num, -1) - mask = ( - data["fut_mask"] - .transpose(1, 2)[..., None] - .reshape(bs * fut_len, agent_num, -1) - ) - seq_out_diag *= mask - xdim, udim = self.dyn.xdim, self.dyn.udim - if self.cfg.dist_obj == "state": - if self.cfg.dist_type == "gaussian": - mu_u = seq_out_diag[..., :udim] - logvar_u = seq_out_diag[..., udim : 2 * udim] - if self.cfg.output_varx: - logvar_x = seq_out_diag[..., 2 * udim : 2 * udim + xdim] - var_x = torch.exp(logvar_x) - scene_noise_M = seq_out_diag[..., 2 * udim + xdim :].reshape( - *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim - ) - else: - var_x = ( - torch.tensor(self.cfg.min_var_x, device=mu_u.device)[None, None] - * torch.ones( - [bs * fut_len, agent_num, self.dyn.xdim], device=mu_u.device - ) - * mask - ) - scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( - *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim - ) - K = scene_noise_M.reshape( - bs * fut_len, agent_num, -1, self.cfg.scene_var_dim - ) - dist = MADynGaussian(mu_u, torch.exp(logvar_u), var_x, K, self.dyn) - elif self.cfg.dist_type == "GMM": - seq_out_diag = seq_out_diag.reshape( - bs * fut_len, agent_num, self.GMM_M, -1 - ).transpose(1, 2) - mu_u = seq_out_diag[..., : self.forecast_dim] - logvar_u = seq_out_diag[..., self.forecast_dim : 2 * self.forecast_dim] - if self.cfg.output_varx: - logvar_x = seq_out_diag[..., 2 * udim : 2 * udim + xdim] - var_x = torch.exp(logvar_x) - scene_noise_M = seq_out_diag[..., 2 * udim + xdim :].reshape( - *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim - ) - else: - scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( - *seq_out_diag.shape[:-1], - self.forecast_dim, - self.cfg.scene_var_dim, - ) - var_x = torch.tensor(self.cfg.min_var_x, device=mu_u.device)[ - None, None, None - ] * torch.ones( - [bs * fut_len, self.GMM_M, agent_num, self.dyn.xdim], - device=mu_u.device, - ) - K = scene_noise_M - - tf_feature = tf_out_tiled.reshape( - [fut_len + 1, agent_num, bs, fut_len, -1] - ) - tf_feature_diag = ( - torch.diagonal(tf_feature, dim1=0, dim2=3) - .permute(1, 3, 0, 2) - .reshape(bs * fut_len, agent_num, -1) - ) - tf_feature_pooled = tf_feature_diag.max(1)[0] - logpi = self.GMM_pi_net(tf_feature_pooled) - pi = torch.softmax(logpi, -1) - dist = MADynGMM(mu_u, torch.exp(logvar_u), var_x, K, pi, self.dyn) - xp = state_seq[1:].permute(2, 0, 1, 3).reshape(bs * fut_len, agent_num, -1) - x0 = ( - state_seq[:fut_len] - .permute(2, 0, 1, 3) - .reshape(bs * fut_len, agent_num, -1) - ) - - log_prob = dist.get_log_likelihood(x0, xp, mask).reshape(bs, fut_len) - elif self.cfg.dist_obj == "input": - if self.cfg.dist_type == "gaussian": - mu_u = seq_out_diag[..., :udim] - logvar_u = seq_out_diag[..., udim : 2 * udim] - scene_noise_M = seq_out_diag[..., 2 * udim :].reshape( - *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim - ) - K = scene_noise_M.reshape( - bs * fut_len, agent_num, -1, self.cfg.scene_var_dim - ) - dist = MAGaussian(mu_u, torch.exp(logvar_u), K) - elif self.cfg.dist_type == "GMM": - seq_out_diag = seq_out_diag.reshape( - bs * fut_len, agent_num, self.GMM_M, -1 - ).transpose(1, 2) - mu_u = seq_out_diag[..., :udim] - logvar_u = seq_out_diag[..., udim : 2 * udim] - K = seq_out_diag[..., 2 * udim :].reshape( - *seq_out_diag.shape[:-1], udim, self.cfg.scene_var_dim - ) - tf_feature = tf_out_tiled.reshape( - [fut_len + 1, agent_num, bs, fut_len, -1] - ) - tf_feature_diag = ( - torch.diagonal(tf_feature, dim1=0, dim2=3) - .permute(1, 3, 0, 2) - .reshape(bs * fut_len, agent_num, -1) - ) - tf_feature_pooled = tf_feature_diag.max(1)[0] - logpi = self.GMM_pi_net(tf_feature_pooled) - pi = torch.softmax(logpi, -1) - dist = MAGMM(mu_u, torch.exp(logvar_u), K, pi) - - xp = state_seq[1:].permute(2, 0, 1, 3).reshape(bs * fut_len, agent_num, -1) - x0 = ( - state_seq[:fut_len] - .permute(2, 0, 1, 3) - .reshape(bs * fut_len, agent_num, -1) - ) - up = self.dyn.inverse_dyn(x0, xp) - log_prob = dist.get_log_likelihood(up, mask) - - data["log_prob"] = log_prob - data["dist"] = dist - - def forward( - self, - data, - mode, - sample_num=1, - autoregress=True, - z=None, - need_weights=False, - cond_idx=None, - temp=0.1, - predict=False, - gt_step=0, - ): - agent_num, bs = ( - data["fut_motion"].size(1), - data["fut_motion"].size(2), - ) - if mode == "train": - assert sample_num == 1 - # conditional input to the decoding process - context = data["context_enc"].repeat_interleave( - sample_num, dim=1 - ) # PN x BS x feat - - pre_motion = data["pre_motion"].repeat_interleave( - sample_num, dim=2 - ) # P x N X BS x 2 - fut_motion = data["fut_motion"].repeat_interleave( - sample_num, dim=2 - ) # F x N X BS x 2 - pre_motion_scene_norm = data["pre_motion_scene_norm"].repeat_interleave( - sample_num, dim=2 - ) # P x N x BS x 2 - fut_motion_scene_norm = data["fut_motion_scene_norm"].repeat_interleave( - sample_num, dim=2 - ) # F x N x BS x 2 - input_dict = dict( - pre_motion=pre_motion, - fut_motion=fut_motion, - pre_motion_scene_norm=pre_motion_scene_norm, - fut_motion_scene_norm=fut_motion_scene_norm, - ) - if self.pred_type == "vel": - input_dict["pre_vel"] = data["pre_vel"].repeat_interleave( - sample_num, dim=2 - ) # P x N x BS x 2 - input_dict["fut_vel"] = data["fut_vel"].repeat_interleave( - sample_num, dim=2 - ) # F x N x BS x 2 - elif self.pred_type in ["dynamic", "dynamic_var", "dynamic_AR"]: - traj_scale = data["traj_scale"] - pre_state = torch.cat( - ( - data["pre_motion"] * traj_scale, - torch.norm(data["pre_vel"], dim=-1, keepdim=True) * traj_scale, - data["pre_heading_raw"].transpose(0, 2).unsqueeze(-1), - ), - -1, - ) # P x N x B x 4 (unscaled) - fut_state = torch.cat( - ( - data["fut_motion"] * traj_scale, - torch.norm(data["fut_vel"], dim=-1, keepdim=True) * traj_scale, - data["fut_heading_raw"].transpose(0, 2).unsqueeze(-1), - ), - -1, - ) - - pre_state_vec = torch.cat( - (data["pre_motion"], data["pre_vel"], data["pre_heading_vec"]), -1 - ) # P x N x B x 6 (scaled) - fut_state_vec = torch.cat( - (data["fut_motion"], data["fut_vel"], data["fut_heading_vec"]), -1 - ) # F x N x B x 6 (scaled) - input_dict["curr_state"] = pre_state[-1].repeat_interleave( - sample_num, dim=1 - ) - input_dict["pre_state"] = pre_state.repeat_interleave(sample_num, dim=2) - input_dict["fut_state"] = fut_state.repeat_interleave(sample_num, dim=2) - input_dict["pre_state_vec"] = pre_state_vec.repeat_interleave( - sample_num, dim=2 - ) - input_dict["fut_state_vec"] = fut_state_vec.repeat_interleave( - sample_num, dim=2 - ) - - # p(z), compute prior distribution - if self.z_type != "None": - if mode == "infer": - prior_key = "p_z_dist_infer" - else: - prior_key = "q_z_dist" if "q_z_dist" in data else "p_z_dist" - - if self.learn_prior: - p_z_params0 = self.p_z_net(data["agent_context"]) - - h = data["agent_context"].repeat_interleave( - sample_num, dim=0 - ) # NBS x feat - p_z_params = self.p_z_net(h) # NBS x 64 - if self.z_type == "gaussian": - data["p_z_dist_infer"] = Normal(params=p_z_params) - data["p_z_dist"] = Normal(params=p_z_params0) - else: - data["p_z_dist_infer"] = Categorical(logits=p_z_params, temp=temp) - data["p_z_dist"] = Categorical(logits=p_z_params0, temp=temp) - else: - if self.z_type == "gaussian": - data[prior_key] = Normal( - mu=torch.zeros(pre_motion.shape[1], self.nz).to( - pre_motion.device - ), - logvar=torch.zeros(pre_motion.shape[1], self.nz).to( - pre_motion.device - ), - ) - else: - data[prior_key] = Categorical( - logits=torch.zeros(pre_motion.shape[1], self.nz).to( - pre_motion.device - ) - ) - - # sample latent code from the distribution - if z is None: - # use latent code z from posterior for training - - # use latent code z from posterior for evaluating the reconstruction loss - if mode == "recon": - z = data["q_z_dist"].mode() # NB x 32 - z = z.view(agent_num, bs, z.size(-1)) # N x B x 32 - - # use latent code z from the prior for inference - elif mode in ["infer", "train"]: - # dist = data["p_z_dist_infer"] if "p_z_dist_infer" in data else data["q_z_dist_infer"] - # z = dist.sample() # NBS x 32 - # import pdb - # pdb.set_trace() - - dist = ( - data["q_z_dist"] - if data["q_z_dist"] is not None - else data["p_z_dist"] - ) - if self.z_type == "gaussian": - if predict: - z = dist.pseudo_sample(sample_num) - else: - z = data["p_z_dist_infer"].sample() - D = z.shape[-1] - samples = z.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3) - mu = dist.mu.reshape(agent_num, bs, -1, D).permute(1, 0, 2, 3)[ - :, :, 0 - ] - sigma = dist.sigma.reshape(agent_num, bs, -1, D).permute( - 1, 0, 2, 3 - )[:, :, 0] - data["prob"] = self.pseudo_sample_prob( - samples, mu, sigma, data["agent_avail"] - ) - elif self.z_type == "discrete": - if predict: - z = dist.pseudo_sample(sample_num).contiguous() - else: - z = dist.rsample(sample_num).contiguous() - D = z.shape[-1] - idx = z.argmax(dim=-1) - prob_sample = torch.gather(dist.probs, -1, idx) - prob_sample = prob_sample.reshape(agent_num, bs, -1).mean(0) - prob_sample = prob_sample / prob_sample.sum(-1, keepdim=True) - data["prob"] = prob_sample - - z = z.view(agent_num, bs * sample_num, z.size(-1)) # N x BS x 32 - else: - raise ValueError("Unknown Mode!") - else: - z = None - - # trajectory decoding - # if mode=="train": - # self.calc_traj_likelihood(data, context, input_dict, z) - # elif mode=="infer": - # self.calc_traj_likelihood(data, context, input_dict, z) - # self.decode_traj_ar( - # data, - # mode, - # context, - # input_dict, - # z, - # sample_num, - # gt_step = gt_step, - # need_weights=need_weights, - # ) - - # else: - # raise NotImplementedError - self.calc_traj_likelihood(data, context, input_dict, z) - self.decode_traj_ar( - data, - mode, - context, - input_dict, - z, - sample_num, - gt_step=gt_step, - need_weights=need_weights, - ) - - def pseudo_sample_prob(self, sample, mu, sigma, mask): - """ - A simple K-means estimation to estimate the probability of samples - """ - bs, Na, Ns, D = sample.shape - device = sample.device - Np = Ns * 50 - particle = torch.randn([bs, Na, Np, D]).to(device) * sigma.unsqueeze( - -2 - ) + mu.unsqueeze(-2) - dis = torch.linalg.norm(sample.unsqueeze(-2) - particle.unsqueeze(-3), dim=-1) - dis = (dis * mask[..., None, None]).sum(1) - idx = torch.argmin(dis, -2) - flag = idx.unsqueeze(1) == torch.arange(Ns).view(1, Ns, 1).repeat_interleave( - bs, 0 - ).to(device) - prob = flag.sum(-1) / Np - return prob - - -""" AgentFormer """ - - -class AgentFormer(nn.Module): - def __init__(self, cfg): - super().__init__() - - self.cfg = cfg - - input_type = cfg.input_type - pred_type = cfg.pred_type - if type(input_type) == str: - input_type = [input_type] - fut_input_type = cfg.fut_input_type - dec_input_type = cfg.dec_input_type - - self.use_map = cfg.use_map - self.rand_rot_scene = cfg.rand_rot_scene - self.discrete_rot = cfg.discrete_rot - self.map_global_rot = cfg.map_global_rot - self.ar_train = cfg.ar_train - self.max_train_agent = cfg.max_train_agent - self.loss_cfg = cfg.loss_cfg - self.param_annealers = nn.ModuleList() - self.z_type = cfg.z_type - if self.z_type == "discrete": - z_tau_annealer = ExpParamAnnealer( - cfg.z_tau.start, cfg.z_tau.finish, cfg.z_tau.decay - ) - self.param_annealers.append(z_tau_annealer) - self.z_tau_annealer = z_tau_annealer - - self.ego_conditioning = cfg.ego_conditioning - self.step_time = cfg.step_time - self.dyn = dynamics.Unicycle(cfg.step_time) - self.DoubleIntegrator = dynamics.DoubleIntegrator(cfg.step_time) - if "perturb" in cfg and cfg.perturb.enabled: - self.N_pert = cfg.perturb.N_pert - theta = cfg.perturb.OU.theta - sigma = cfg.perturb.OU.sigma - scale = torch.tensor(cfg.perturb.OU.scale) - self.pert = DynOrnsteinUhlenbeckPerturbation( - theta * torch.ones(self.dyn.udim), sigma * scale, self.dyn - ) - else: - self.N_pert = 0 - self.pert = None - if "stage" in cfg: - assert cfg.stage * cfg.num_frames_per_stage <= cfg.future_num_frames - self.stage = cfg.stage - self.num_frames_per_stage = cfg.num_frames_per_stage - else: - self.stage = 1 - self.num_frames_per_stage = cfg.future_num_frames - - # save all computed variables - self.data = dict() - - # map encoder - if self.use_map: - self.map_encoder = base_models.RasterizedMapEncoder( - model_arch=cfg.map_encoder.model_architecture, - input_image_shape=cfg.map_encoder.image_shape, - feature_dim=cfg.map_encoder.feature_dim, - use_spatial_softmax=cfg.map_encoder.spatial_softmax.enabled, - spatial_softmax_kwargs=cfg.map_encoder.spatial_softmax.kwargs, - ) - - # models - self.context_encoder = ContextEncoder(cfg) - self.future_encoder = FutureEncoder(cfg) - self.future_decoder = FutureDecoder(cfg) - - def set_data(self, batch, stage=0): - device = batch["pre_motion_raw"].device - self.data[stage] = batch - self.data[stage]["step_time"] = self.step_time - bs, Na = batch["pre_motion_raw"].shape[:2] - self.data[stage]["pre_motion"] = ( - batch["pre_motion_raw"].to(device).transpose(0, 2).contiguous() - ) # P x N x B x 2 - self.data[stage]["fut_motion"] = ( - batch["fut_motion_raw"].to(device).transpose(0, 2).contiguous() - ) # F x N x B x 2 - - # compute the origin of the current scene, i.e., the center - # of the agents' location in the current frame - self.data[stage]["scene_orig"] = torch.nanmean( - self.data[stage]["pre_motion"][-1], dim=0 - ) # B x 2 - - # normalize the scene with respect to the center location - # optionally, also rotate the scene for augmentation - if self.rand_rot_scene and self.training: - # below cannot be fixed in seed, causing reproducibility issue - if self.discrete_rot: - theta = torch.randint(high=24, size=(1,)).to(device) * (np.pi / 12) - else: - theta = torch.rand(1).to(device) * np.pi * 2 # [0, 2*pi], full circle - - for key in ["pre_motion", "fut_motion"]: - ( - self.data[stage][f"{key}"], - self.data[stage][f"{key}_scene_norm"], - ) = rotation_2d_torch( - self.data[stage][key], theta, self.data[stage]["scene_orig"] - ) - if self.data[stage]["heading"] is not None: - self.data[stage]["heading"] += theta # B x N - else: - theta = torch.zeros(1).to(device) - - # normalize per scene - for key in ["pre_motion", "fut_motion"]: # (F or P) x N x B x 2 - self.data[stage][f"{key}_scene_norm"] = ( - self.data[stage][key] - self.data[stage]["scene_orig"] - ) - - # normalize pos per agent - self.data[stage]["cur_motion"] = self.data[stage]["pre_motion"][ - [-1] - ] # 1 x N x B x 2 - self.data[stage]["pre_motion_norm"] = ( - self.data[stage]["pre_motion"][:-1] - - self.data[stage]["cur_motion"] # P x N x B x 2 - ) - self.data[stage]["fut_motion_norm"] = ( - self.data[stage]["fut_motion"] - self.data[stage]["cur_motion"] - ) # F x N x B x 2 - - # vectorize heading - if self.data[stage]["heading"] is not None: - self.data[stage]["heading_vec"] = torch.stack( - [ - torch.cos(self.data[stage]["heading"]), - torch.sin(self.data[stage]["heading"]), - ], - dim=-1, - ).transpose(0, 1) - # N x B x 2 - self.data[stage]["pre_heading_vec"] = torch.stack( - [ - torch.cos(self.data[stage]["pre_heading_raw"]), - torch.sin(self.data[stage]["pre_heading_raw"]), - ], - dim=-1, - ).transpose(0, 2) - # P x N x B x 2 - - self.data[stage]["fut_heading_vec"] = torch.stack( - [ - torch.cos(self.data[stage]["fut_heading_raw"]), - torch.sin(self.data[stage]["fut_heading_raw"]), - ], - dim=-1, - ).transpose(0, 2) - # F x N x B x 2 - - # agent shuffling, default not shuffling - if self.training and self.cfg["agent_enc_shuffle"]: - self.data[stage]["agent_enc_shuffle"] = torch.randperm( - self.cfg["max_agent_len"] - )[: self.data[stage]["agent_num"]].to(device) - else: - self.data[stage]["agent_enc_shuffle"] = None - - # mask between pairwse agents, such as diable connection for a pair of agents - # that are far away from each other, currently not used, i.e., assuming all connections - conn_dist = self.cfg.conn_dist - cur_motion = self.data[stage]["cur_motion"][0] - if conn_dist < 1000.0: - threshold = conn_dist / self.cfg.traj_scale - pdist = F.pdist(cur_motion) - D = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) - D[np.triu_indices(cur_motion.shape[0], 1)] = pdist - D += D.T - mask = torch.zeros_like(D) - mask[D > threshold] = float("-inf") - else: - mask = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) - self.data[stage][ - "agent_mask" - ] = mask # N x N, all zeros now, i.e., fully-connected - - def step_annealer(self): - for anl in self.param_annealers: - anl.step() - - def convert_data(self, batch, cond_traj=None, predict=False): - data = defaultdict(lambda: None) - ego_traj = torch.cat((batch["fut_pos"][:, 0], batch["fut_yaw"][:, 0]), -1) - external_cond = True if cond_traj is not None else False - if cond_traj is None and not predict: - if self.pert is not None: - # always perturb the ego trajectory - - ego_traj_tiled = ego_traj.repeat_interleave(self.N_pert, 0) - avail = batch["fut_mask"][:, 0].repeat_interleave(self.N_pert, 0) - pert_dict = self.pert.perturb( - dict( - fut_pos=ego_traj_tiled[..., :2], - fut_yaw=ego_traj_tiled[..., 2:], - fut_mask=avail, - step_time=self.step_time, - ) - ) - pert_ego_positions = pert_dict["fut_pos"] - pert_ego_trajectories = torch.cat( - (pert_dict["fut_pos"], pert_dict["fut_yaw"]), -1 - ) - - cond_traj = torch.cat( - ( - ego_traj.unsqueeze(1), - TensorUtils.reshape_dimensions_single( - pert_ego_trajectories, 0, 1, [-1, self.N_pert] - ), - ), - 1, - ) - else: - cond_traj = ego_traj.unsqueeze(1) - - device = batch["hist_pos"].device - bs = batch["hist_yaw"].shape[0] - data["heading"] = batch["hist_yaw"][:, :, -1, 0].to(device) # B x N - data["pre_heading_raw"] = batch["hist_yaw"][..., 0].to(device) # B x N x P - data["fut_heading_full"] = batch["fut_yaw"][..., 0].to(device) - data["fut_heading_raw"] = data["fut_heading_full"][ - ..., : self.num_frames_per_stage - ] # B x N x F - traj_scale = self.cfg.traj_scale - data["traj_scale"] = traj_scale - # AgentFormer uses the x/y inputs, i.e., the first two dimensions - data["pre_motion_raw"] = (batch["hist_pos"] / traj_scale).to( - device - ) # B x N x P x 2 - data["fut_motion_full"] = (batch["fut_pos"] / traj_scale).to(device) - data["fut_motion_raw"] = ( - batch["fut_pos"][:, :, : self.num_frames_per_stage] / traj_scale - ).to( - device - ) # B x N x F x 2 - - data["pre_mask"] = ( - batch["hist_mask"].float().to(device) - ) # B x N x P # B x N x F x 2 - data["fut_mask_full"] = batch["fut_mask"].float().to(device) # B x N x F - data["fut_mask"] = data["fut_mask_full"][..., : self.num_frames_per_stage] - data["agent_avail"] = data["pre_mask"].any(-1).float() - data["image"] = batch["image"] - if cond_traj is not None and self.ego_conditioning: - Ne = cond_traj.shape[1] - for k in [ - "heading", - "pre_motion_raw", - "fut_motion_full", - "pre_mask", - "fut_mask_full", - "fut_mask", - "agent_avail", - "image", - "pre_heading_raw", - "fut_heading_raw", - "fut_heading_full", - ]: - data[k] = ( - data[k].repeat_interleave(Ne, 0) if data[k] is not None else None - ) - - fut_motion_full = data["fut_motion_full"] * traj_scale - cond_traj_tiled = TensorUtils.join_dimensions(cond_traj, 0, 2) - fut_motion_full[:, 0] = cond_traj_tiled[..., :2] - data["fut_heading_full"][:, 0] = cond_traj_tiled[..., 2] - data["fut_heading_raw"] = data["fut_heading_full"][ - ..., : self.num_frames_per_stage - ] - data["fut_motion_full"] = fut_motion_full / traj_scale - data["fut_motion_raw"] = data["fut_motion_full"][ - :, :, : self.num_frames_per_stage - ] - - if self.ego_conditioning: - data["cond_traj"] = cond_traj - else: - data["cond_traj"] = None - data["pre_vel"] = self.DoubleIntegrator.calculate_vel( - data["pre_motion_raw"], None, data["pre_mask"].bool() - ) - data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() - data["fut_vel"] = self.DoubleIntegrator.calculate_vel( - data["fut_motion_raw"], None, data["fut_mask"].bool() - ) # F x N x B x 2 - data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() - - return data - - def gen_data_stage(self, batch, pred_traj, stage): - if stage == 0: - return batch - else: - data = defaultdict(lambda: None) - device = pred_traj.device - # fields that does not change - for k in [ - "traj_scale", - "agent_enc_shuffle", - "fut_motion_full", - "fut_mask_full", - ]: - data[k] = batch[k] - traj_scale = self.cfg.traj_scale - bs, M, Na = pred_traj.shape[:3] - data["heading"] = batch["heading"].repeat_interleave(M, 0) # (B*M) x N - - Ts = self.num_frames_per_stage - P = self.cfg.history_num_frames - F = self.cfg.future_num_frames - if Ts < P: - # left over from previous stage - - prev_stage_hist_pos = batch["pre_motion_raw"][ - :, :, Ts - P : - ].repeat_interleave( - M, 0 - ) # (B*M) x N x (P-Ts) x 2 - prev_stage_hist_yaw = batch["pre_heading_raw"][ - :, :, Ts - P : - ].repeat_interleave( - M, 0 - ) # (B*M) x N x (P-Ts) - new_hist_pos = TensorUtils.join_dimensions( - pred_traj[:, :, :, :Ts, :2], 0, 2 - ) # (B*M) x N x Ts x 2 - new_hist_yaw = TensorUtils.join_dimensions( - pred_traj[:, :, :, :Ts, 2], 0, 2 - ) # (B*M) x N x Ts - - data["pre_motion_raw"] = torch.cat( - (prev_stage_hist_pos, new_hist_pos), 2 - ) # (B*M) x N x P x 2 - data["pre_heading_raw"] = torch.cat( - (prev_stage_hist_yaw, new_hist_yaw), 2 - ) # (B*M) x N x P - - prev_stage_pre_mask = batch["pre_mask"][ - :, :, Ts - P : - ].repeat_interleave( - M, 0 - ) # (B*M) x N x (P-Ts) - # since this is associated with the predicted trajectory, all entries is True except for dummy agents - new_stage_pre_mask = ( - batch["agent_avail"] - .unsqueeze(-1) - .repeat_interleave(M, 0) - .repeat_interleave(Ts, -1) - ) # (B*M) x N x Ts - data["pre_mask"] = torch.cat( - (prev_stage_pre_mask, new_stage_pre_mask), -1 - ) # (B*M) x N x P - else: - data["pre_motion_raw"] = TensorUtils.join_dimensions( - pred_traj[:, :, :, -P:, :2], 0, 2 - ) # (B*M) x N x P x 2 - data["pre_heading_raw"] = TensorUtils.join_dimensions( - pred_traj[:, :, :, -P:], 0, 2 - ) # (B*M) x N x P - data["pre_mask"] = ( - batch["agent_avail"] - .unsqueeze(-1) - .repeat_interleave(M, 0) - .repeat_interleave(P, -1) - ) # (B*M) x N x P - # for future motion, pad the unknown future with 0 - - data["fut_motion_raw"] = batch["fut_motion_full"][ - ..., stage * Ts : (stage + 1) * Ts, : - ].repeat_interleave( - M**stage, 0 - ) # (B*M) x N x Ts x 2 - data["fut_heading_raw"] = batch["fut_heading_full"][ - ..., stage * Ts : (stage + 1) * Ts - ].repeat_interleave( - M**stage, 0 - ) # (B*M) x N x Ts - - data["fut_mask"] = batch["fut_mask_full"][ - ..., stage * Ts : (stage + 1) * Ts - ].repeat_interleave(M**stage, 0) - - data["agent_avail"] = batch["agent_avail"].repeat_interleave(M, 0) - - data["pre_vel"] = self.DoubleIntegrator.calculate_vel( - data["pre_motion_raw"], None, data["pre_mask"].bool() - ) - data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() - data["fut_vel"] = self.DoubleIntegrator.calculate_vel( - data["fut_motion_raw"], None, data["fut_mask"].bool() - ) # F x N x B x 2 - data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() - if data["map_enc"] is not None: - data["map_enc"] = batch["map_enc"].repeat_interleave( - M, 1 - ) # N x (B*M) x D - - return data - - def forward(self, batch, sample_k=None, predict=False, **kwargs): - cond_traj = kwargs["cond_traj"] if "cond_traj" in kwargs else None - data0 = self.convert_data(batch, cond_traj=cond_traj, predict=predict) - pred_traj = None - pred_batch = dict() - pred_batch["p_z_dist"] = dict() - pred_batch["q_z_dist"] = dict() - if self.ego_conditioning: - cond_idx = [0] - else: - cond_idx = None - data_stage = data0 - for stage in range(self.stage): - data_stage = self.gen_data_stage(data_stage, pred_traj, stage) - self.set_data(data_stage, stage) - pred_data = self.run_model( - stage, sample_k, predict=predict, cond_idx=cond_idx - ) - pred_traj = pred_data["infer_dec_motion"] - if "infer_dec_state" not in pred_data: - yaws = torch.zeros_like(pred_traj[..., 0:1]) - else: - yaws = pred_data["infer_dec_state"][..., 3:] - - pred_traj = torch.cat((pred_traj, yaws), -1) - pred_batch["p_z_dist"][stage] = pred_data["p_z_dist"] - pred_batch["q_z_dist"][stage] = pred_data["q_z_dist"] - - positions, state, var, controls = self.batching_multistage_traj() - positions = positions * self.cfg.traj_scale - NeB, numMode, Na, F = positions.shape[:4] - bs = batch["hist_pos"].shape[0] - Ne = int(NeB / bs) - if state is None: - yaws = batch["hist_yaw"][:, :, [-1]].repeat_interleave(F, 2) - - yaws = ( - yaws.unsqueeze(1).repeat_interleave(Ne, 0).repeat_interleave(numMode, 1) - ) - trajectories = torch.cat((positions, yaws), -1) - else: - trajectories = state[..., [0, 1, 3]] - if "prob" not in self.data[0]: - prob = ( - torch.ones(trajectories.shape[:2]).to(trajectories.device) - / trajectories.shape[1] - ) - prob = prob / prob.sum(-1, keepdim=True) - else: - M = int(numMode ** (1 / self.stage)) - prob = self.data[self.stage - 1]["prob"].reshape( - bs * Ne, *([M] * self.stage) - ) - - for stage in range(self.stage - 1): - desired_shape = ( - [bs * Ne] + [M] * (stage + 1) + [1] * (self.stage - stage - 1) - ) - prob = prob * TensorUtils.reshape_dimensions( - self.data[stage]["prob"], 0, 2, desired_shape - ) - prob = TensorUtils.join_dimensions(prob, 1, self.stage + 1) - - pred_except_dist = dict( - trajectories=trajectories, - state_trajectory=state, - p=prob, - fut_pos=data0["fut_motion_full"] * self.cfg.traj_scale, - ) - pred_except_dist = TensorUtils.reshape_dimensions( - pred_except_dist, 0, 1, [bs, Ne] - ) - pred_batch.update(pred_except_dist) - if controls is not None: - pred_batch["controls"] = controls - pred_batch["cond_traj"] = data0["cond_traj"] - agent_avail = self.data[0]["agent_avail"] - agent_avail = agent_avail.reshape([bs, Ne, -1])[:, 0] - pred_batch["agent_avail"] = agent_avail - pred_batch.update(self._traj_to_preds(pred_batch["trajectories"])) - if var is not None: - pred_batch["var"] = var - if not predict: - self.step_annealer() - else: - pred_batch = {k: v for k, v in pred_batch.items() if "dist" not in k} - pred_batch["data_batch"] = batch - del data0 - return pred_batch - - def batching_multistage_traj(self): - if "infer_dec_motion" in self.data[0]: - infer_traj = list() - bs, M = self.data[0]["infer_dec_motion"].shape[:2] - for stage in range(self.stage): - traj_i = self.data[stage]["infer_dec_motion"].repeat_interleave( - (M ** (self.stage - stage - 1)), 0 - ) - traj_i = traj_i.reshape(bs, M**self.stage, *traj_i.shape[2:]) - infer_traj.append(traj_i) - infer_traj = torch.cat(infer_traj, -2) - else: - infer_traj = None - if "infer_dec_state" in self.data[0]: - infer_state = list() - bs, M = self.data[0]["infer_dec_state"].shape[:2] - for stage in range(self.stage): - state_i = self.data[stage]["infer_dec_state"].repeat_interleave( - (M ** (self.stage - stage - 1)), 0 - ) - state_i = state_i.reshape(bs, M**self.stage, *state_i.shape[2:]) - infer_state.append(state_i) - infer_state = torch.cat(infer_state, -2) - else: - infer_state = None - if "infer_var" in self.data[0]: - infer_var = list() - bs, M = self.data[0]["infer_var"].shape[:2] - for stage in range(self.stage): - var_i = self.data[stage]["infer_var"].repeat_interleave( - (M ** (self.stage - stage - 1)), 0 - ) - var_i = var_i.reshape(bs, M**self.stage, *var_i.shape[2:]) - infer_var.append(var_i) - infer_var = torch.cat(infer_var, -2) - else: - infer_var = None - - if "controls" in self.data[0]: - controls = list() - bs, M = self.data[0]["controls"].shape[:2] - for stage in range(self.stage): - controls_i = self.data[stage]["controls"].repeat_interleave( - (M ** (self.stage - stage - 1)), 0 - ) - controls_i = controls_i.reshape( - bs, M**self.stage, *controls_i.shape[2:] - ) - controls.append(controls_i) - controls = torch.cat(controls, -2) - else: - controls = None - - return infer_traj, infer_state, infer_var, controls - - def sample(self, batch, sample_k): - return self.forward(batch, sample_k) - - def run_model(self, stage, sample_k=None, predict=False, cond_idx=None): - if self.use_map and self.data[stage]["map_enc"] is None: - image = self.data[0]["image"] - bs, Na = image.shape[:2] - map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) - map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) - self.data[stage]["map_enc"] = map_enc.transpose(0, 1) - self.context_encoder(self.data[stage]) - if not predict: - self.future_encoder(self.data[stage]) - # self.future_decoder(self.data[stage], mode='train', autoregress=self.ar_train) - - if sample_k is None: - self.inference( - sample_num=self.cfg.sample_k, - stage=stage, - cond_idx=cond_idx, - predict=predict, - ) - else: - self.inference( - sample_num=sample_k, stage=stage, cond_idx=cond_idx, predict=predict - ) - - # self.data[stage]["cond_traj"] = None - return self.data[stage] - - def inference( - self, - mode="infer", - sample_num=20, - need_weights=False, - stage=0, - cond_idx=None, - predict=False, - ): - if self.use_map and self.data[stage]["map_enc"] is None: - image = self.data[0]["image"] - bs, Na = image.shape[:2] - map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) - map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) - self.data[stage]["map_enc"] = map_enc.transpose(0, 1) - if self.data[stage]["context_enc"] is None: - self.context_encoder(self.data[stage]) - if mode == "recon": - sample_num = 1 - self.future_encoder(self.data[stage], temp=self.z_tau_annealer.val()) - - if self.z_type == "gaussian": - temp = None - else: - temp = 0.0001 if predict else self.z_tau_annealer.val() - # raise Exception("one of p and q need to exist") - - self.future_decoder( - self.data[stage], - mode=mode, - sample_num=sample_num, - autoregress=True, - need_weights=need_weights, - cond_idx=cond_idx, - temp=temp, - predict=predict, - ) - return self.data[stage][f"{mode}_dec_motion"], self.data - - def _traj_to_preds(self, traj): - pred_positions = traj[..., :2] - pred_yaws = traj[..., 2:3] - return { - "trajectories": traj, - "predictions": {"positions": pred_positions, "yaws": pred_yaws}, - } - - def compute_losses(self, pred_batch, data_batch): - if "data_batch" in pred_batch: - data_batch = pred_batch["data_batch"] - device = pred_batch["trajectories"].device - bs, Ne, numMode, Na = pred_batch["trajectories"].shape[:4] - pred_batch["trajectories"] = pred_batch["trajectories"].nan_to_num(0) - M = int(numMode ** (1 / self.stage)) - kl_loss = torch.tensor(0.0, device=device) - if "q_z_dist" in pred_batch and "p_z_dist" in pred_batch: - for stage in range(self.stage): - kl_loss += ( - pred_batch["q_z_dist"][stage] - .kl(pred_batch["p_z_dist"][stage]) - .nan_to_num(0) - .sum(-1) - .mean() - ) - - kl_loss = kl_loss.clamp_min_(self.cfg.loss_cfg.kld.min_clip) - - traj_pred_tiled = TensorUtils.join_dimensions(pred_batch["trajectories"], 0, 2) - traj_pred_tiled2 = TensorUtils.join_dimensions(traj_pred_tiled, 0, 2) - cond_traj = pred_batch["cond_traj"] - if Ne > 1: - fut_mask = data_batch["fut_mask"].repeat_interleave(Ne, 0) - else: - fut_mask = data_batch["fut_mask"] - - pred_loss, goal_loss = MultiModal_trajectory_loss( - predictions=traj_pred_tiled[..., :2], - targets=TensorUtils.join_dimensions(pred_batch["fut_pos"], 0, 2), - availabilities=fut_mask, - prob=TensorUtils.join_dimensions(pred_batch["p"], 0, 2), - calc_goal_reach=False, - ) - extent = data_batch["extent"][..., :2] - div_score = diversity_score( - traj_pred_tiled[..., :2], - fut_mask.unsqueeze(1).repeat_interleave(numMode, 1).any(1), - ) - # cond_extent = extent[torch.arange(bs),pred_batch["cond_idx"]] - if pred_batch["cond_traj"] is not None: - if "EC_coll_loss" in pred_batch: - EC_coll_loss = pred_batch["EC_coll_loss"] - else: - EC_edges, type_mask = batch_utils().gen_EC_edges( - traj_pred_tiled2[:, 1:], - cond_traj.reshape(bs * Ne, 1, -1, 3) - .repeat_interleave(numMode, 0) - .repeat_interleave(Na - 1, 1), - extent[:, 0].repeat_interleave(Ne * numMode, 0), - extent[:, 1:].repeat_interleave(Ne * numMode, 0), - data_batch["type"][:, 1:].repeat_interleave(Ne * numMode, 0), - pred_batch["agent_avail"].repeat_interleave(Ne * numMode, 0)[:, 1:], - ) - - EC_edges = TensorUtils.reshape_dimensions( - EC_edges, 0, 1, (bs, Ne, numMode) - ) - type_mask = TensorUtils.reshape_dimensions( - type_mask, 0, 1, (bs, Ne, numMode) - ) - prob = pred_batch["p"] - EC_coll_loss = collision_loss_masked( - EC_edges, type_mask, weight=prob.reshape(bs, Ne, -1).unsqueeze(-1) - ) - if not isinstance(EC_coll_loss, torch.Tensor): - EC_coll_loss = torch.tensor(EC_coll_loss).to(device) - else: - EC_coll_loss = torch.tensor(0.0).to(device) - - # compute collision loss - - pred_edges = batch_utils().generate_edges( - pred_batch["agent_avail"].repeat_interleave(numMode * Ne, 0), - extent.repeat_interleave(Ne * numMode, 0), - traj_pred_tiled2[..., :2], - traj_pred_tiled2[..., 2:], - ) - - coll_loss = collision_loss(pred_edges=pred_edges) - if not isinstance(coll_loss, torch.Tensor): - coll_loss = torch.tensor(coll_loss).to(device) - - losses = OrderedDict( - prediction_loss=pred_loss, - kl_loss=kl_loss, - collision_loss=coll_loss, - EC_collision_loss=EC_coll_loss, - diversity_loss=-div_score, - ) - - if "controls" in pred_batch: - acce_reg_loss = (pred_batch["controls"][..., 0] ** 2).mean() - steering_reg_loss = (pred_batch["controls"][..., 1] ** 2).mean() - losses["acce_reg_loss"] = acce_reg_loss - losses["steering_reg_loss"] = steering_reg_loss - - # if self.cfg.input_weight_scaling is not None and "controls" in pred_batch: - # input_weight_scaling = torch.tensor(self.cfg.input_weight_scaling).to(pred_batch["controls"].device) - # losses["input_loss"] = torch.mean(pred_batch["controls"] ** 2 *pred_batch["mask"][...,None]*input_weight_scaling) - - return losses - - -class ARAgentFormer(nn.Module): - def __init__(self, cfg): - super().__init__() - - self.cfg = cfg - - input_type = cfg.input_type - pred_type = cfg.pred_type - if type(input_type) == str: - input_type = [input_type] - fut_input_type = cfg.fut_input_type - dec_input_type = cfg.dec_input_type - - self.use_map = cfg.use_map - self.rand_rot_scene = cfg.rand_rot_scene - self.discrete_rot = cfg.discrete_rot - self.map_global_rot = cfg.map_global_rot - self.ar_train = cfg.ar_train - self.max_train_agent = cfg.max_train_agent - self.loss_cfg = cfg.loss_cfg - self.param_annealers = nn.ModuleList() - self.z_type = cfg.z_type - if self.z_type == "discrete": - z_tau_annealer = ExpParamAnnealer( - cfg.z_tau.start, cfg.z_tau.finish, cfg.z_tau.decay - ) - self.param_annealers.append(z_tau_annealer) - self.z_tau_annealer = z_tau_annealer - # if "gt_step_anneal_length" in cfg and cfg.gt_step_anneal_length>0: - # self.gt_step_annealer = IntegerParamAnnealer(cfg.future_num_frames-1,0,cfg.gt_step_anneal_length) - # self.param_annealers.append(self.gt_step_annealer) - self.step_time = cfg.step_time - self.dyn = dynamics.Unicycle(cfg.step_time) - self.DoubleIntegrator = dynamics.DoubleIntegrator(cfg.step_time) - - # save all computed variables - self.data = dict() - - # map encoder - if self.use_map: - self.map_encoder = base_models.RasterizedMapEncoder( - model_arch=cfg.map_encoder.model_architecture, - input_image_shape=cfg.map_encoder.image_shape, - feature_dim=cfg.map_encoder.feature_dim, - use_spatial_softmax=cfg.map_encoder.spatial_softmax.enabled, - spatial_softmax_kwargs=cfg.map_encoder.spatial_softmax.kwargs, - ) - - # models - self.context_encoder = ContextEncoder(cfg) - self.future_encoder = FutureEncoder(cfg) - self.future_decoder = FutureARDecoder(cfg) - - self.future_num_frames = cfg.future_num_frames - self.history_num_frames = cfg.history_num_frames - - def set_data(self, batch): - device = batch["pre_motion_raw"].device - self.data = batch - self.data["step_time"] = self.step_time - bs, Na = batch["pre_motion_raw"].shape[:2] - self.data["pre_motion"] = ( - batch["pre_motion_raw"].to(device).transpose(0, 2).contiguous() - ) # P x N x B x 2 - self.data["fut_motion"] = ( - batch["fut_motion_raw"].to(device).transpose(0, 2).contiguous() - ) # F x N x B x 2 - - # compute the origin of the current scene, i.e., the center - # of the agents' location in the current frame - self.data["scene_orig"] = torch.nanmean( - self.data["pre_motion"][-1], dim=0 - ) # B x 2 - - # normalize the scene with respect to the center location - # optionally, also rotate the scene for augmentation - if self.rand_rot_scene and self.training: - # below cannot be fixed in seed, causing reproducibility issue - if self.discrete_rot: - theta = torch.randint(high=24, size=(1,)).to(device) * (np.pi / 12) - else: - theta = torch.rand(1).to(device) * np.pi * 2 # [0, 2*pi], full circle - - for key in ["pre_motion", "fut_motion"]: - ( - self.data[f"{key}"], - self.data[f"{key}_scene_norm"], - ) = rotation_2d_torch(self.data[key], theta, self.data["scene_orig"]) - if self.data["heading"] is not None: - self.data["heading"] += theta # B x N - else: - theta = torch.zeros(1).to(device) - - # normalize per scene - for key in ["pre_motion", "fut_motion"]: # (F or P) x N x B x 2 - self.data[f"{key}_scene_norm"] = ( - self.data[key] - self.data["scene_orig"] - ) - - # normalize pos per agent - self.data["cur_motion"] = self.data["pre_motion"][[-1]] # 1 x N x B x 2 - self.data["pre_motion_norm"] = ( - self.data["pre_motion"][:-1] - self.data["cur_motion"] # P x N x B x 2 - ) - self.data["fut_motion_norm"] = ( - self.data["fut_motion"] - self.data["cur_motion"] - ) # F x N x B x 2 - - # vectorize heading - if self.data["heading"] is not None: - self.data["heading_vec"] = torch.stack( - [torch.cos(self.data["heading"]), torch.sin(self.data["heading"])], - dim=-1, - ).transpose(0, 1) - # N x B x 2 - self.data["pre_heading_vec"] = torch.stack( - [ - torch.cos(self.data["pre_heading_raw"]), - torch.sin(self.data["pre_heading_raw"]), - ], - dim=-1, - ).transpose(0, 2) - # P x N x B x 2 - - self.data["fut_heading_vec"] = torch.stack( - [ - torch.cos(self.data["fut_heading_raw"]), - torch.sin(self.data["fut_heading_raw"]), - ], - dim=-1, - ).transpose(0, 2) - # F x N x B x 2 - - # agent shuffling, default not shuffling - if self.training and self.cfg["agent_enc_shuffle"]: - self.data["agent_enc_shuffle"] = torch.randperm(self.cfg["max_agent_len"])[ - : self.data["agent_num"] - ].to(device) - else: - self.data["agent_enc_shuffle"] = None - - # mask between pairwse agents, such as diable connection for a pair of agents - # that are far away from each other, currently not used, i.e., assuming all connections - conn_dist = self.cfg.conn_dist - cur_motion = self.data["cur_motion"][0] - if conn_dist < 1000.0: - threshold = conn_dist / self.cfg.traj_scale - pdist = F.pdist(cur_motion) - D = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) - D[np.triu_indices(cur_motion.shape[0], 1)] = pdist - D += D.T - mask = torch.zeros_like(D) - mask[D > threshold] = float("-inf") - else: - mask = torch.zeros([cur_motion.shape[0], cur_motion.shape[0]]).to(device) - self.data["agent_mask"] = mask # N x N, all zeros now, i.e., fully-connected - - def step_annealer(self): - for anl in self.param_annealers: - anl.step() - - def convert_data(self, batch): - data = defaultdict(lambda: None) - - device = batch["hist_pos"].device - bs = batch["hist_yaw"].shape[0] - data["heading"] = batch["hist_yaw"][:, :, -1, 0].to(device) # B x N - data["pre_heading_raw"] = batch["hist_yaw"][..., 0].to(device) # B x N x P - data["fut_heading_full"] = batch["fut_yaw"][..., 0].to(device) - data["fut_heading_raw"] = data["fut_heading_full"][ - ..., : self.future_num_frames - ] # B x N x F - traj_scale = self.cfg.traj_scale - data["traj_scale"] = traj_scale - # AgentFormer uses the x/y inputs, i.e., the first two dimensions - data["pre_motion_raw"] = (batch["hist_pos"] / traj_scale).to( - device - ) # B x N x P x 2 - data["fut_motion_full"] = (batch["fut_pos"] / traj_scale).to(device) - data["fut_motion_raw"] = ( - batch["fut_pos"][:, :, : self.future_num_frames] / traj_scale - ).to( - device - ) # B x N x F x 2 - - data["pre_mask"] = ( - batch["hist_mask"].float().to(device) - ) # B x N x P # B x N x F x 2 - data["fut_mask_full"] = batch["fut_mask"].float().to(device) # B x N x F - data["fut_mask"] = data["fut_mask_full"][..., : self.future_num_frames] - data["agent_avail"] = data["pre_mask"].any(-1).float() - data["image"] = batch["image"] - - data["pre_vel"] = self.DoubleIntegrator.calculate_vel( - data["pre_motion_raw"], None, data["pre_mask"].bool() - ) - data["pre_vel"] = data["pre_vel"].transpose(0, 2).contiguous() - data["fut_vel"] = self.DoubleIntegrator.calculate_vel( - data["fut_motion_raw"], None, data["fut_mask"].bool() - ) # F x N x B x 2 - data["fut_vel"] = data["fut_vel"].transpose(0, 2).contiguous() - - return data - - def forward(self, batch, sample_k=None, predict=False, **kwargs): - data = self.convert_data(batch) - pred_batch = dict() - pred_batch["p_z_dist"] = dict() - pred_batch["q_z_dist"] = dict() - - self.set_data(data) - pred_data = self.run_model(sample_k, predict=predict) - - mode = "infer" if predict else "train" - if mode == "infer": - yaws = pred_data[f"{mode}_dec_state"][..., 3:] - - pred_batch["p_z_dist"] = pred_data["p_z_dist"] - pred_batch["q_z_dist"] = pred_data["q_z_dist"] - - positions = self.data[f"{mode}_dec_motion"] - state = self.data[f"{mode}_dec_state"] - controls = self.data["controls"] - positions = positions * self.cfg.traj_scale - bs, numMode, Na, F = positions.shape[:4] - if state is None: - yaws = batch["hist_yaw"][:, :, [-1]].repeat_interleave(F, 2) - - yaws = yaws.unsqueeze(1).repeat_interleave(numMode, 1) - trajectories = torch.cat((positions, yaws), -1) - else: - trajectories = state[..., [0, 1, 3]] - if "prob" not in self.data: - prob = ( - torch.ones(trajectories.shape[:2]).to(trajectories.device) - / trajectories.shape[1] - ) - prob = prob / prob.sum(-1, keepdim=True) - else: - prob = self.data["prob"].reshape(bs, -1) - - pred_except_dist = dict( - trajectories=trajectories, - state_trajectory=state, - p=prob, - fut_pos=self.data["fut_motion_full"] * self.cfg.traj_scale, - ) - pred_batch.update(pred_except_dist) - if controls is not None: - pred_batch["controls"] = controls - - pred_batch["agent_avail"] = self.data["agent_avail"] - pred_batch.update(self._traj_to_preds(pred_batch["trajectories"])) - pred_batch = {k: v for k, v in pred_batch.items() if "dist" not in k} - pred_batch["data_batch"] = batch - return pred_batch - else: - self.step_annealer() - return self.data - - def sample(self, batch, sample_k): - return self.forward(batch, sample_k) - - def run_model(self, sample_k=None, predict=False): - if self.use_map and self.data["map_enc"] is None: - image = self.data["image"] - bs, Na = image.shape[:2] - map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) - map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) - self.data["map_enc"] = map_enc.transpose(0, 1) - self.context_encoder(self.data) - mode = "infer" if predict else "train" - if mode == "infer": - if sample_k is None: - self.inference(sample_num=self.cfg.sample_k, predict=predict, mode=mode) - else: - self.inference(sample_num=sample_k, predict=predict, mode=mode) - else: - self.train_model(mode=mode) - return self.data - - def inference(self, mode="infer", sample_num=20, need_weights=False, predict=False): - if self.use_map and self.data["map_enc"] is None: - image = self.data["image"] - bs, Na = image.shape[:2] - map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) - map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) - self.data["map_enc"] = map_enc.transpose(0, 1) - if self.data["context_enc"] is None: - self.context_encoder(self.data) - - if self.z_type == "gaussian": - temp = None - else: - temp = 0.0001 if predict else self.z_tau_annealer.val() - # raise Exception("one of p and q need to exist") - - self.future_decoder( - self.data, - mode=mode, - sample_num=sample_num, - autoregress=True, - need_weights=need_weights, - temp=temp, - predict=predict, - gt_step=0, - ) - return self.data[f"{mode}_dec_motion"], self.data - - def train_model(self, mode="train", need_weights=False): - if self.use_map and self.data["map_enc"] is None: - image = self.data["image"] - bs, Na = image.shape[:2] - map_enc = self.map_encoder(TensorUtils.join_dimensions(image, 0, 2)) - map_enc = map_enc.reshape(bs, Na, *map_enc.shape[1:]) - self.data["map_enc"] = map_enc.transpose(0, 1) - if self.data["context_enc"] is None: - self.context_encoder(self.data) - - if self.z_type == "discrete": - temp = self.z_tau_annealer.val() - else: - temp = None - # raise Exception("one of p and q need to exist") - - self.future_decoder( - self.data, - mode=mode, - sample_num=1, - autoregress=True, - need_weights=need_weights, - temp=temp, - ) - - return self.data - - def _traj_to_preds(self, traj): - pred_positions = traj[..., :2] - pred_yaws = traj[..., 2:3] - return { - "trajectories": traj, - "predictions": {"positions": pred_positions, "yaws": pred_yaws}, - } - - def compute_losses(self, pred_batch, data_batch): - losses = dict() - if "log_prob" in pred_batch: - log_prob_loss = -pred_batch["log_prob"].mean() - losses["log_prob"] = log_prob_loss - if "q_z_dist" in pred_batch and pred_batch["q_z_dist"] is not None: - kl_loss = ( - pred_batch["q_z_dist"] - .kl(pred_batch["p_z_dist"]) - .nan_to_num(0) - .sum(-1) - .mean() - ) - losses["kl_loss"] = kl_loss - - if "trajectories" in pred_batch: - # inference mode - if "data_batch" in pred_batch: - data_batch = pred_batch["data_batch"] - device = pred_batch["trajectories"].device - bs, numMode, Na = pred_batch["trajectories"].shape[:3] - pred_batch["trajectories"] = pred_batch["trajectories"].nan_to_num(0) - - traj_pred = pred_batch["trajectories"] - traj_pred_tiled2 = TensorUtils.join_dimensions(traj_pred, 0, 2) - - fut_mask = data_batch["fut_mask"] - - pred_loss, goal_loss = MultiModal_trajectory_loss( - predictions=traj_pred[..., :2], - targets=pred_batch["fut_pos"], - availabilities=fut_mask, - prob=pred_batch["p"], - calc_goal_reach=False, - ) - extent = data_batch["extent"][..., :2] - div_score = diversity_score( - traj_pred[..., :2], - fut_mask.unsqueeze(1).repeat_interleave(numMode, 1).any(1), - ) - - # compute collision loss - - pred_edges = batch_utils().generate_edges( - pred_batch["agent_avail"].repeat_interleave(numMode, 0), - extent.repeat_interleave(numMode, 0), - traj_pred_tiled2[..., :2], - traj_pred_tiled2[..., 2:], - ) - - coll_loss = collision_loss(pred_edges=pred_edges) - if not isinstance(coll_loss, torch.Tensor): - coll_loss = torch.tensor(coll_loss).to(device) - - pred_losses = OrderedDict( - prediction_loss=pred_loss, - collision_loss=coll_loss, - diversity_loss=-div_score, - ) - # if "controls" in pred_batch: - # scale = self.cfg.loss_weights.input_loss_scale if "input_loss_scale" in self.cfg.loss_weights else 1.0 - # acce_reg_loss = (pred_batch["controls"][...,0]**2).mean() - # steering_reg_loss = (pred_batch["controls"][...,1]**2).mean() - # pred_losses["acce_reg_loss"] = acce_reg_loss*scale - # pred_losses["steering_reg_loss"] = steering_reg_loss*scale - # pred_losses["acce_jerk_loss"] = torch.mean((pred_batch["controls"][...,1:,0]-pred_batch["controls"][...,:-1,0])**2)*scale - # pred_losses["steering_jerk_loss"] = torch.mean((pred_batch["controls"][...,1:,1]-pred_batch["controls"][...,:-1,1])**2)*scale - losses.update(pred_losses) - - # if self.cfg.input_weight_scaling is not None and "controls" in pred_batch: - # input_weight_scaling = torch.tensor(self.cfg.input_weight_scaling).to(pred_batch["controls"].device) - # losses["input_loss"] = torch.mean(pred_batch["controls"] ** 2 *pred_batch["mask"][...,None]*input_weight_scaling) - - return losses diff --git a/diffstack/models/agentformer_lib.py b/diffstack/models/agentformer_lib.py deleted file mode 100644 index 2e7eebb..0000000 --- a/diffstack/models/agentformer_lib.py +++ /dev/null @@ -1,1044 +0,0 @@ -""" -Modified version of PyTorch Transformer module for the implementation of Agent-Aware Attention (L290-L308) -""" - -from typing import Callable, Dict, Final, List, Optional, Set, Tuple, Union -import warnings -import math -import copy - -import torch -from torch import Tensor -import torch.nn as nn -from torch.nn import functional as F -from torch.nn.functional import * -from torch.nn.modules.module import Module -from torch.nn.modules.activation import MultiheadAttention -from torch.nn.modules.container import ModuleList -from torch.nn.init import xavier_uniform_ -from torch.nn.modules.dropout import Dropout -from torch.nn.modules.linear import Linear -from torch.nn.modules.normalization import LayerNorm -from torch.nn.init import xavier_uniform_ -from torch.nn.init import constant_ -from torch.nn.init import xavier_normal_ -from torch.nn.parameter import Parameter -from torch.overrides import has_torch_function, handle_torch_function -from torchvision import models - -def compute_z_kld(data, cfg): - loss_unweighted = data['q_z_dist_dlow'].kl(data['p_z_dist_infer']).sum() - if cfg.get('normalize', True): - loss_unweighted /= data['batch_size'] - loss_unweighted = loss_unweighted.clamp_min_(cfg.min_clip) - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - -def diversity_loss(data, cfg): - loss_unweighted = 0 - fut_motions = data['infer_dec_motion'].view(*data['infer_dec_motion'].shape[:2], -1) - for motion in fut_motions: - dist = F.pdist(motion, 2) ** 2 - loss_unweighted += (-dist / cfg['d_scale']).exp().mean() - if cfg.get('normalize', True): - loss_unweighted /= data['batch_size'] - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - -def recon_loss(data, cfg): - diff = data['infer_dec_motion'] - data['fut_motion_orig'].unsqueeze(1) - if cfg.get('mask', True): - mask = data['fut_mask'].unsqueeze(1).unsqueeze(-1) - diff *= mask - dist = diff.pow(2).sum(dim=-1).sum(dim=-1) - loss_unweighted = dist.min(dim=1)[0] - if cfg.get('normalize', True): - loss_unweighted = loss_unweighted.mean() - else: - loss_unweighted = loss_unweighted.sum() - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - - -# """ DLow (Diversifying Latent Flows)""" -# class DLow(nn.Module): -# def __init__(self, cfg): -# super().__init__() - -# self.device = torch.device('cpu') -# self.cfg = cfg -# self.nk = nk = cfg.sample_k -# self.nz = nz = cfg.nz -# self.share_eps = cfg.get('share_eps', True) -# self.train_w_mean = cfg.get('train_w_mean', False) -# self.loss_cfg = self.cfg.loss_cfg -# self.loss_names = list(self.loss_cfg.keys()) - -# pred_cfg = Config(cfg.pred_cfg, cfg.train_tag, tmp=False, create_dirs=False) -# pred_model = model_lib.model_dict[pred_cfg.model_id](pred_cfg) -# self.pred_model_dim = pred_cfg.tf_model_dim -# if cfg.pred_epoch > 0: -# cp_path = pred_cfg.model_path % cfg.pred_epoch -# print('loading model from checkpoint: %s' % cp_path) -# model_cp = torch.load(cp_path, map_location='cpu') -# pred_model.load_state_dict(model_cp['model_dict']) -# pred_model.eval() -# self.pred_model = [pred_model] - -# # Dlow's Q net -# self.qnet_mlp = cfg.get('qnet_mlp', [512, 256]) -# self.q_mlp = MLP(self.pred_model_dim, self.qnet_mlp) -# self.q_A = nn.Linear(self.q_mlp.out_dim, nk * nz) -# self.q_b = nn.Linear(self.q_mlp.out_dim, nk * nz) - -# def set_device(self, device): -# self.device = device -# self.to(device) -# self.pred_model[0].set_device(device) - -# def set_data(self, data): -# self.pred_model[0].set_data(data) -# self.data = self.pred_model[0].data - -# def main(self, mean=False, need_weights=False): -# pred_model = self.pred_model[0] -# if hasattr(pred_model, 'use_map') and pred_model.use_map: -# self.data['map_enc'] = pred_model.map_encoder(self.data['agent_maps']) -# pred_model.context_encoder(self.data) - -# if not mean: -# if self.share_eps: -# eps = torch.randn([1, self.nz]).to(self.device) -# eps = eps.repeat((self.data['agent_num'] * self.nk, 1)) -# else: -# eps = torch.randn([self.data['agent_num'], self.nz]).to(self.device) -# eps = eps.repeat_interleave(self.nk, dim=0) - -# qnet_h = self.q_mlp(self.data['agent_context']) -# A = self.q_A(qnet_h).view(-1, self.nz) -# b = self.q_b(qnet_h).view(-1, self.nz) - -# z = b if mean else A*eps + b -# logvar = (A ** 2 + 1e-8).log() -# self.data['q_z_dist_dlow'] = Normal(mu=b, logvar=logvar) - -# pred_model.future_decoder(self.data, mode='infer', sample_num=self.nk, autoregress=True, z=z, need_weights=need_weights) -# return self.data - -# def forward(self): -# return self.main(mean=self.train_w_mean) - -# def inference(self, mode, sample_num, need_weights=False): -# self.main(mean=True, need_weights=need_weights) -# res = self.data[f'infer_dec_motion'] -# if mode == 'recon': -# res = res[:, 0] -# return res, self.data - -# def compute_loss(self): -# total_loss = 0 -# loss_dict = {} -# loss_unweighted_dict = {} -# for loss_name in self.loss_names: -# loss, loss_unweighted = loss_func[loss_name](self.data, self.loss_cfg[loss_name]) -# total_loss += loss -# loss_dict[loss_name] = loss.item() -# loss_unweighted_dict[loss_name] = loss_unweighted.item() -# return total_loss, loss_dict, loss_unweighted_dict - -# def step_annealer(self): -# pass - -def agent_aware_attention(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - add_zero_attn: bool, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - use_separate_proj_weight: bool = False, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None, - gaussian_kernel = True, - num_agent = 1, - in_proj_weight_self = None, - in_proj_bias_self = None - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - See "Attention Is All You Need" for more details. - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - bias_k, bias_v: bias of the key and value sequences to be added at dim=0. - add_zero_attn: add a new batch of zeros to the key and - value sequences at dim=1. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - use_separate_proj_weight: the function accept the proj. weights for query, key, - and value in different forms. If false, in_proj_weight will be used, which is - a combination of q_proj_weight, k_proj_weight, v_proj_weight. - q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. - static_k, static_v: static key and value used for attention operators. - - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, - N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, - N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - if not torch.jit.is_scripting(): - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, - out_proj_weight, out_proj_bias) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multi_head_attention_forward, tens_ops, query, key, value, - embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, - bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias, training=training, key_padding_mask=key_padding_mask, - need_weights=need_weights, attn_mask=attn_mask, - use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) - tgt_len, bs, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - # allow MHA to have different sizes for the feature dimension - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - # # replace nan with 0, so that we can use to check if it is a self-attention vs cross-attention - # query_no_nan = torch.nan_to_num(query.clone()) - # key_no_nan = torch.nan_to_num(key.clone()) - # value_no_nan = torch.nan_to_num(value.clone()) - - if not use_separate_proj_weight: - if torch.equal(query, key) and torch.equal(key, value): # PN x B X feat - # if torch.equal(query_no_nan, key_no_nan) and torch.equal(key_no_nan, value_no_nan): # PN x B X feat - - q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) # PN x B x feat - if in_proj_weight_self is not None: - q_self, k_self = linear(query, in_proj_weight_self, in_proj_bias_self).chunk(2, dim=-1) - - # elif torch.equal(key_no_nan, value_no_nan): - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = linear(query, _w, _b) - - if key is None: - assert value is None - k = None - v = None - else: - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = linear(key, _w, _b).chunk(2, dim=-1) - - if in_proj_weight_self is not None: - _w = in_proj_weight_self[:embed_dim, :] - _b = in_proj_bias_self[:embed_dim] - q_self = linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _w = in_proj_weight_self[embed_dim:, :] - _b = in_proj_bias_self[embed_dim:] - k_self = linear(key, _w, _b) - - else: - raise NotImplementedError - - else: - q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) - len1, len2 = q_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == query.size(-1) - - k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) - len1, len2 = k_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == key.size(-1) - - v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) - len1, len2 = v_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == value.size(-1) - - if in_proj_bias is not None: - q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) - k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) - v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) - else: - q = linear(query, q_proj_weight_non_opt, in_proj_bias) - k = linear(key, k_proj_weight_non_opt, in_proj_bias) - v = linear(value, v_proj_weight_non_opt, in_proj_bias) - # k, q, v has PN x B X feat, q maybe FN x B x feat - - # default gaussian_kernel = False - if not gaussian_kernel: - q = q * scaling # remove scaling - if in_proj_weight_self is not None: - q_self = q_self * scaling # remove scaling - - if attn_mask is not None: - assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ - attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ - 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) - if attn_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [bs * num_heads, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') - else: - raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") - key_padding_mask = key_padding_mask.to(torch.bool) - - if bias_k is not None and bias_v is not None: - if static_k is None and static_v is None: - k = torch.cat([k, bias_k.repeat(1, bs, 1)]) - v = torch.cat([v, bias_v.repeat(1, bs, 1)]) - if attn_mask is not None: - attn_mask = pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = pad(key_padding_mask, (0, 1)) - else: - assert static_k is None, "bias cannot be added to static key." - assert static_v is None, "bias cannot be added to static value." - else: - assert bias_k is None - assert bias_v is None - - q = q.contiguous().view(tgt_len, bs * num_heads, head_dim).transpose(0, 1) # BH x PN x feat - if k is not None: - k = k.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) # BH x PN x feat - if v is not None: - v = v.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) - if in_proj_weight_self is not None: - q_self = q_self.contiguous().view(tgt_len, bs * num_heads, head_dim).transpose(0, 1) - k_self = k_self.contiguous().view(-1, bs * num_heads, head_dim).transpose(0, 1) - - if static_k is not None: - assert static_k.size(0) == bs * num_heads - assert static_k.size(2) == head_dim - k = static_k - - if static_v is not None: - assert static_v.size(0) == bs * num_heads - assert static_v.size(2) == head_dim - v = static_v - - src_len = k.size(1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bs - assert key_padding_mask.size(1) == src_len - - if add_zero_attn: - src_len += 1 - k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) - if attn_mask is not None: - attn_mask = pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = pad(key_padding_mask, (0, 1)) - - if gaussian_kernel: - qk = torch.bmm(q, k.transpose(1, 2)) - q_n = q.pow(2).sum(dim=-1).unsqueeze(-1) - k_n = k.pow(2).sum(dim=-1).unsqueeze(1) - qk_dist = q_n + k_n - 2 * qk - attn_output_weights = qk_dist * scaling * 0.5 - else: - attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # BH x PN x PN, or BH x FN x PN - # attention weights contain random numbers for the timestamps without data - - assert list(attn_output_weights.size()) == [bs * num_heads, tgt_len, src_len] - - if in_proj_weight_self is not None: - """ - ================================== - Agent-Aware Attention - ================================== - """ - attn_output_weights_inter = attn_output_weights # BH x PN x PN, or BH x FN x PN - attn_output_weights_self = torch.bmm(q_self, k_self.transpose(1, 2)) # BH x PN x PN, or BH x FN x PN - - # using identity matrix here since the agents are not shuffled - attn_weight_self_mask = torch.eye(num_agent).to(q.device) - attn_weight_self_mask = attn_weight_self_mask.repeat([attn_output_weights.shape[1] // num_agent, attn_output_weights.shape[2] // num_agent]).unsqueeze(0) - # 1 x PN x PN - - attn_output_weights = attn_output_weights_inter * (1 - attn_weight_self_mask) + attn_output_weights_self * attn_weight_self_mask # BH x PN x PN - - # masking the columns / rows - if attn_mask is not None: # BH x PN x PN or BH x FN x PN - if attn_mask.dtype == torch.bool: - - # assign -inf so that the columns of the invalid data will lead to 0 after softmax during attention - # this is to disable the interaction between valid agents in rows and invalid agents in some columns - attn_output_weights.masked_fill_(attn_mask, float('-inf')) # BH x PN x PN or BH x FN x PN - - # however, the rows with invalid data (with all -inf now) will lead to NaN after softmax - # because there is no single valid data for that row - # as a result, it will lead to NaN in backward during softmax - # we need to assign some dummy numbers to it, this process is needed for training - # but this does not affect the results in forward pass since these rows of features are not used - attn_output_weights.masked_fill_(attn_mask[:, :, [0]], 0.0) - else: - attn_output_weights += attn_mask # BH x PN x PN or BH x FN x PN - - attn_output_weights = softmax(attn_output_weights, dim=-1) # BH x PN x PN or BH x FN x PN - # the output attn_output_weights will have 0.17 for the rows with invalid data - # because the entire row prior to softmax is 0, so it is averaged over the row - - # to suppress the random number in the row with invalida data - # we mask again with 0s, however in-place operation not supported for backward - # attn_output_weights.masked_fill_(attn_mask[[0], :, [0]].unsqueeze(-1), 0.0) - else: - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float('-inf')) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view(bs, num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), - ) - attn_output_weights = attn_output_weights.view(bs * num_heads, tgt_len, src_len) - attn_output_weights = softmax( - attn_output_weights, dim=-1) - - # attn_output_weights is row-wise, i.e., the agent at a timestamp without valid data (NaN) has some random numbers - # for the columns, the agent at a timestamp without valid data is 0 - # add torch.nan_to_num to convert NaN to a large number for the agent value without valid data - # but when the 0 in attn_output_weights col * the large number in v, it will result in 0 - # in other words, we do not attend to the agent timestamp with NaN data - # the output might have some invalid rows (rows with random numbers but do not affect results) - attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) # BH x PN x PN, or BH x FN x PN - attn_output = torch.bmm(attn_output_weights, torch.nan_to_num(v, nan=1e+10)) # BH x PN x feat, or BH x FN x feat - - # to maintain elegancy, we mask those invalid rows with random numbers as 0s - # but not masking will not affect results - # final_mask = attn_mask[:, :, [0]] # 1 x PN x 1 - # attn_output = attn_output.masked_fill_(final_mask, 0.0) # BH x PN x feat - - # convert to output shape - assert list(attn_output.size()) == [bs * num_heads, tgt_len, head_dim] - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bs, embed_dim) # PN x B x feat - attn_output = linear(attn_output, out_proj_weight, out_proj_bias) # PN x B x feat - - # average attention weights over heads - if need_weights: - attn_output_weights = attn_output_weights.view(bs, num_heads, tgt_len, src_len) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class AgentAwareAttention(Module): - r"""Allows the model to jointly attend to information - from different representation subspaces. - See reference: Attention Is All You Need - - .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - bias: add bias as module parameter. Default: True. - add_bias_kv: add bias to the key and value sequences at dim=0. - add_zero_attn: add a new batch of zeros to the key and - value sequences at dim=1. - kdim: total number of features in key. Default: None. - vdim: total number of features in value. Default: None. - - Note: if kdim and vdim are None, they will be set to embed_dim such that - query, key, and value have the same number of features. - - Examples:: - - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - """ - bias_k: Optional[torch.Tensor] - bias_v: Optional[torch.Tensor] - - def __init__(self, cfg, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): - super().__init__() - self.cfg = cfg - self.gaussian_kernel = self.cfg.get('gaussian_kernel', False) - self.sep_attn = self.cfg.get('sep_attn', True) - self.embed_dim = embed_dim - self.kdim = kdim if kdim is not None else embed_dim - self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim - - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - - if self._qkv_same_embed_dim is False: - self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) - self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) - self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) - self.register_parameter('in_proj_weight', None) - else: - self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) - self.register_parameter('q_proj_weight', None) - self.register_parameter('k_proj_weight', None) - self.register_parameter('v_proj_weight', None) - - if bias: - self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) - else: - self.register_parameter('in_proj_bias', None) - self.out_proj = Linear(embed_dim, embed_dim) - - if add_bias_kv: - self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) - self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) - else: - self.bias_k = self.bias_v = None - - self.add_zero_attn = add_zero_attn - - if self.sep_attn: - self.in_proj_weight_self = Parameter(torch.empty(2 * embed_dim, embed_dim)) - self.in_proj_bias_self = Parameter(torch.empty(2 * embed_dim)) - else: - self.in_proj_weight_self = self.in_proj_bias_self = None - - self._reset_parameters() - - def _reset_parameters(self): - if self._qkv_same_embed_dim: - xavier_uniform_(self.in_proj_weight) - else: - xavier_uniform_(self.q_proj_weight) - xavier_uniform_(self.k_proj_weight) - xavier_uniform_(self.v_proj_weight) - - if self.in_proj_bias is not None: - constant_(self.in_proj_bias, 0.) - constant_(self.out_proj.bias, 0.) - if self.bias_k is not None: - xavier_normal_(self.bias_k) - if self.bias_v is not None: - xavier_normal_(self.bias_v) - - if self.sep_attn: - xavier_uniform_(self.in_proj_weight_self) - constant_(self.in_proj_bias_self, 0.) - - def __setstate__(self, state): - # Support loading old MultiheadAttention checkpoints generated by v1.1.0 - if '_qkv_same_embed_dim' not in state: - state['_qkv_same_embed_dim'] = True - - super().__setstate__(state) - - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None, num_agent=1): - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - See "Attention Is All You Need" for more details. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - if not self._qkv_same_embed_dim: - return agent_aware_attention( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, gaussian_kernel=self.gaussian_kernel, - num_agent=num_agent, - in_proj_weight_self=self.in_proj_weight_self, - in_proj_bias_self=self.in_proj_bias_self - ) - else: - return agent_aware_attention( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, gaussian_kernel=self.gaussian_kernel, - num_agent=num_agent, - in_proj_weight_self=self.in_proj_weight_self, - in_proj_bias_self=self.in_proj_bias_self - ) - - -class AgentFormerEncoderLayer(Module): - r"""TransformerEncoderLayer is made up of self-attn and feedforward network. - This standard encoder layer is based on the paper "Attention Is All You Need". - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, - Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in - Neural Information Processing Systems, pages 6000-6010. Users may modify or implement - in a different way during application. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__(self, cfg, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): - super().__init__() - self.cfg = cfg - self.self_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = Linear(d_model, dim_feedforward) - self.dropout = Dropout(dropout) - self.linear2 = Linear(dim_feedforward, d_model) - - self.norm1 = LayerNorm(d_model) - self.norm2 = LayerNorm(d_model) - self.dropout1 = Dropout(dropout) - self.dropout2 = Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def __setstate__(self, state): - if 'activation' not in state: - state['activation'] = F.relu - super().__setstate__(state) - - def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1) -> torch.Tensor: - r"""Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - src2 = self.self_attn(src, src, src, attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, num_agent=num_agent)[0] - src = src + self.dropout1(src2) - src = self.norm1(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - src = self.norm2(src) - return src - - -class AgentFormerDecoderLayer(Module): - r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. - This standard decoder layer is based on the paper "Attention Is All You Need". - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, - Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in - Neural Information Processing Systems, pages 6000-6010. Users may modify or implement - in a different way during application. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__(self, cfg, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): - super().__init__() - self.cfg = cfg - self.self_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) - self.multihead_attn = AgentAwareAttention(cfg, d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = Linear(d_model, dim_feedforward) - self.dropout = Dropout(dropout) - self.linear2 = Linear(dim_feedforward, d_model) - - self.norm1 = LayerNorm(d_model) - self.norm2 = LayerNorm(d_model) - self.norm3 = LayerNorm(d_model) - self.dropout1 = Dropout(dropout) - self.dropout2 = Dropout(dropout) - self.dropout3 = Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def __setstate__(self, state): - if 'activation' not in state: - state['activation'] = F.relu - super().__setstate__(state) - - def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, num_agent = 1, need_weights = False) -> torch.Tensor: - r"""Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: the sequence to the decoder layer (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, num_agent=num_agent, need_weights=need_weights) - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) - tgt2, cross_attn_weights = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, num_agent=num_agent, need_weights=need_weights) - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = tgt + self.dropout3(tgt2) - tgt = self.norm3(tgt) - return tgt, self_attn_weights, cross_attn_weights - - -class AgentFormerEncoder(Module): - r"""TransformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the TransformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - __constants__ = ['norm'] - - def __init__(self, encoder_layer, num_layers, norm=None): - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1) -> torch.Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - output = src - - for mod in self.layers: - output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, num_agent=num_agent) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class AgentFormerDecoder(Module): - r"""TransformerDecoder is a stack of N decoder layers - - Args: - decoder_layer: an instance of the TransformerDecoderLayer() class (required). - num_layers: the number of sub-decoder-layers in the decoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = transformer_decoder(tgt, memory) - """ - __constants__ = ['norm'] - - def __init__(self, decoder_layer, num_layers, norm=None): - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, num_agent=1, need_weights = False) -> torch.Tensor: - r"""Pass the inputs (and mask) through the decoder layer in turn. - - Args: - tgt: the sequence to the decoder (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - output = tgt - - self_attn_weights = [None] * len(self.layers) - cross_attn_weights = [None] * len(self.layers) - for i, mod in enumerate(self.layers): - output, self_attn_weights[i], cross_attn_weights[i] = mod(output, memory, tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - num_agent=num_agent, need_weights=need_weights) - - if self.norm is not None: - output = self.norm(output) - - if need_weights: - self_attn_weights = torch.stack(self_attn_weights).cpu().numpy() - cross_attn_weights = torch.stack(cross_attn_weights).cpu().numpy() - - return output, {'self_attn_weights': self_attn_weights, 'cross_attn_weights': cross_attn_weights} - - -def _get_clones(module, N): - return ModuleList([copy.deepcopy(module) for i in range(N)]) - - -def _get_activation_fn(activation): - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - -class MapCNN(nn.Module): - def __init__(self, cfg): - super().__init__() - self.convs = nn.ModuleList() - map_channels = cfg.get('map_channels', 3) - patch_size = cfg.get('patch_size', [100, 100]) - hdim = cfg.get('hdim', [32, 32]) - kernels = cfg.get('kernels', [3, 3]) - strides = cfg.get('strides', [3, 3]) - self.out_dim = out_dim = cfg.get('out_dim', 32) - self.input_size = input_size = (map_channels, patch_size[0], patch_size[1]) - x_dummy = torch.randn(input_size).unsqueeze(0) - - for i, _ in enumerate(hdim): - self.convs.append(nn.Conv2d(map_channels if i == 0 else hdim[i-1], - hdim[i], kernels[i], - stride=strides[i])) - x_dummy = self.convs[i](x_dummy) - - self.fc = nn.Linear(x_dummy.numel(), out_dim) - - def forward(self, x): - for conv in self.convs: - x = F.leaky_relu(conv(x), 0.2) - x = torch.flatten(x, start_dim=1) - x = self.fc(x) - return x - -class MapEncoder(nn.Module): - def __init__(self, cfg): - super().__init__() - model_id = cfg.get('model_id', 'map_cnn') - dropout = cfg.get('dropout', 0.0) - self.normalize = cfg.get('normalize', True) - self.dropout = nn.Dropout(dropout) - if model_id == 'map_cnn': - self.model = MapCNN(cfg) - self.out_dim = self.model.out_dim - elif 'resnet' in model_id: - model_dict = { - 'resnet18': models.resnet18, - 'resnet34': models.resnet34, - 'resnet50': models.resnet50 - } - self.out_dim = out_dim = cfg.get('out_dim', 32) - self.model = model_dict[model_id](pretrained=False, norm_layer=nn.InstanceNorm2d) - self.model.fc = nn.Linear(self.model.fc.in_features, out_dim) - else: - raise ValueError('unknown map encoder!') - - def forward(self, x): - if self.normalize: - x = x * 2. - 1. - x = self.model(x) - x = self.dropout(x) - return x - -def compute_motion_mse(data, cfg): - diff = data['fut_motion_orig'] - data['train_dec_motion'] - # print(data['fut_motion_orig']) - # print(data['train_dec_motion']) - # zxc - if cfg.get('mask', True): - mask = data['fut_mask'] - diff *= mask.unsqueeze(2) - loss_unweighted = diff.pow(2).sum() - if cfg.get('normalize', True): - loss_unweighted /= diff.shape[0] - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - -def compute_z_kld(data, cfg): - loss_unweighted = data['q_z_dist'].kl(data['p_z_dist']).sum() - if cfg.get('normalize', True): - loss_unweighted /= data['batch_size'] - loss_unweighted = loss_unweighted.clamp_min_(cfg.min_clip) - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - -def compute_sample_loss(data, cfg): - diff = data['infer_dec_motion'] - data['fut_motion_orig'].unsqueeze(1) - if cfg.get('mask', True): - mask = data['fut_mask'].unsqueeze(1).unsqueeze(-1) - diff *= mask - dist = diff.pow(2).sum(dim=-1).sum(dim=-1) - loss_unweighted = dist.min(dim=1)[0] - if cfg.get('normalize', True): - loss_unweighted = loss_unweighted.mean() - else: - loss_unweighted = loss_unweighted.sum() - loss = loss_unweighted * cfg['weight'] - return loss, loss_unweighted - - -loss_func = { - 'mse': compute_motion_mse, - 'kld': compute_z_kld, - 'sample': compute_sample_loss -} \ No newline at end of file diff --git a/diffstack/modules/predictors/factory.py b/diffstack/modules/predictors/factory.py index d3e87db..3ceb7fb 100644 --- a/diffstack/modules/predictors/factory.py +++ b/diffstack/modules/predictors/factory.py @@ -4,9 +4,7 @@ from diffstack.utils.utils import removeprefix from diffstack.modules.predictors.kinematic_predictor import KinematicTreeModel -from diffstack.modules.predictors.tbsim_predictors import ( - AgentFormerTrafficModel, -) + from diffstack.modules.predictors.CTT import CTTTrafficModel @@ -35,14 +33,6 @@ def predictor_factory( model_registrar, config, logger, device, input_mappings=input_mappings ) - elif algo_name in [ - "agentformer_multistage", - "agentformer_singlestage", - "agentformer", - ]: - predictor = AgentFormerTrafficModel( - model_registrar, config, logger, device, input_mappings=input_mappings - ) elif algo_name == "CTT": predictor = CTTTrafficModel( model_registrar, config, logger, device, input_mappings=input_mappings diff --git a/diffstack/modules/predictors/tbsim_predictors.py b/diffstack/modules/predictors/tbsim_predictors.py deleted file mode 100644 index 21a6c4e..0000000 --- a/diffstack/modules/predictors/tbsim_predictors.py +++ /dev/null @@ -1,477 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -from diffstack.modules.module import Module, DataFormat, RunMode - -from diffstack.utils.utils import traj_xyh_to_xyhv, removeprefix -import diffstack.utils.tensor_utils as TensorUtils -import diffstack.utils.geometry_utils as GeoUtils -from diffstack.utils.batch_utils import batch_utils -from diffstack.models.agentformer import AgentFormer -from trajdata.data_structures.batch import SceneBatch -from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D -import diffstack.utils.tensor_utils as TensorUtils -from typing import Dict, Any - - -from diffstack.utils.loss_utils import ( - collision_loss, -) - - -from trajdata.data_structures import StateTensor, AgentType - - -from diffstack.utils.lane_utils import SimpleLaneRelation -from diffstack.utils.homotopy import ( - identify_pairwise_homotopy, -) -import diffstack.utils.metrics as Metrics - - -class AgentFormerTrafficModel(Module): - @property - def input_format(self) -> DataFormat: - return DataFormat(["scene_batch"]) - - @property - def output_format(self) -> DataFormat: - return DataFormat( - [ - "mixed_pred_ml:validate", - "mixed_pred_ml:infer", - "metrics:train", - "metrics:validate", - "step_time", - ] - ) - - @property - def checkpoint_monitor_keys(self): - return {"valLoss": "val/losses_predictor_prediction_loss"} - - def __init__(self, model_registrar, cfg, log_writer, device, input_mappings={}): - super(AgentFormerTrafficModel, self).__init__( - model_registrar, cfg, log_writer, device, input_mappings=input_mappings - ) - - self.bu = batch_utils(rasterize_mode="none") - self.modality_shapes = self.bu.get_modality_shapes(cfg) - # assert modality_shapes["image"][0] == 15 - self.nets = nn.ModuleDict() - - self.bu = batch_utils(parse=True, rasterize_mode="none") - self.nets["policy"] = AgentFormer(cfg) - self.cfg = cfg - - if "checkpoint" in cfg and cfg.checkpoint["enabled"]: - checkpoint = torch.load(cfg.checkpoint["path"]) - predictor_dict = { - removeprefix(k, "components.predictor."): v - for k, v in checkpoint["state_dict"].items() - if k.startswith("components.predictor.") - } - self.load_state_dict(predictor_dict) - - self.device = device - self.nets["policy"] = self.nets["policy"].to(self.device) - - def set_eval(self): - self.nets["policy"].eval() - - def set_train(self): - self.nets["policy"].train() - - def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: - if isinstance(inputs["scene_batch"], dict) and "batch" in inputs["scene_batch"]: - parsed_batch = self.bu.parse_batch( - inputs["scene_batch"]["batch"].astype(torch.float) - ) - else: - parsed_batch = self.bu.parse_batch( - inputs["scene_batch"].astype(torch.float) - ) - - # tic = torch_utils.tic(timer=self.hyperparams.debug.timer) - parsed_batch_torch = { - k: v.to(self.device) - for k, v in parsed_batch.items() - if isinstance(v, torch.Tensor) - } - parsed_batch.update(parsed_batch_torch) - output = self.nets["policy"]( - parsed_batch, predict=(run_mode == RunMode.INFER), **kwargs - ) - # torch_utils.toc(tic, name="prediction model", timer=self.hyperparams.debug.timer) - - # Convert to standard prediction output - trajs_xyh = output["trajectories"] # b, Ne, mode, N_agent, t, xyh - trajs_xyh = trajs_xyh[:, 0] # b, K, Na, T, 3 - trajs_xyh = trajs_xyh.transpose(1, 2) # b, Na, K, T, 3 - trajs_xyh = trajs_xyh[:, :, None] # # b, Na, S=1, K, T, 3 - # Infer velocity from xy - dt = self.hyperparams["step_time"] # hyperparams is AgentFormerConfig - trajs_xyhv = traj_xyh_to_xyhv(trajs_xyh, dt) - trajs_xyhv = StateTensor.from_array(trajs_xyhv, format="x,y,h,v_lon") - - log_probs = torch.log(output["p"]) # b, Ne, mode - log_probs = log_probs[:, :1, None, :, None] # b, 1, S, K, T=1 - - mus = output["trajectories"][:, 0, ..., :2].permute( - 2, 0, 3, 1, 4 - ) # (Na,b,T,M,2) - Na, bs, Tf, M = mus.shape[:4] - log_pis = ( - torch.log(output["p"][None, :, None, 0]) - .repeat_interleave(Na, 0) - .repeat_interleave(Tf, 2) - ) - log_sigmas = torch.zeros_like(mus) - corrs = torch.zeros_like(log_pis) - y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) - if "state_trajectory" in output and output["state_trajectory"] is not None: - state_traj = output["state_trajectory"][None, :, 0] - # changing state order - state_traj = torch.cat( - [state_traj[..., :2], state_traj[..., 3:], state_traj[..., 2:3]], -1 - ) - output["pred_ml"] = state_traj - else: - output["pred_ml"] = output["trajectories"][None, :, 0] - output["pred_dist"] = y_dists - - output.update(dict(data_batch=parsed_batch)) - if run_mode == RunMode.INFER: - # Convert to standardized prediction output - dt = self.hyperparams["step_time"] # hyperparams is AgentFormerConfig - mus_xyh = output["trajectories"] # b, Ne, mode, N_agent, t, xyh - log_pis = torch.log(output["p"]) # b, Ne, mode - - # pred_dist: GMM - mus_xyh = mus_xyh[:, 0] # b, mode, N_agent, t, xyh - mus_xyh = mus_xyh.permute(2, 0, 1, 3, 4) # (N_agent, b, mode, T, xyh) - # Infer velocity from xy - mus_xyhv = traj_xyh_to_xyhv(mus_xyh, dt) - mus_xyhv = mus_xyhv.transpose(2, 3) # (N_agent, b, T, mode, xyhv) - - # Currently we simply treat joint distribtion as agent-wise marginals. - log_pis = ( - log_pis[:, 0] - .reshape(1, log_pis.shape[0], 1, log_pis.shape[2]) - .repeat(mus_xyhv.shape[0], 1, mus_xyhv.shape[2], 1) - ) # n, b, T, mode - log_sigmas = torch.log( - ( - torch.arange( - 1, - mus_xyhv.shape[2] + 1, - dtype=mus_xyhv.dtype, - device=mus_xyhv.device, - ) - * dt - ) - ** 2 - * 2 - ) - log_sigmas = log_sigmas.reshape(1, 1, mus_xyhv.shape[2], 1, 1).repeat( - (mus_xyhv.shape[0], mus_xyhv.shape[1], 1, mus_xyhv.shape[3], 2) - ) - corrs = 0.0 * torch.ones( - mus_xyhv.shape[:-1], dtype=mus_xyhv.dtype, device=mus_xyhv.device - ) - - pred_dist_with_ego = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) - - # drop ego - if isinstance(inputs["scene_batch"], SceneBatch): - assert (inputs["scene_batch"].extras["robot_ind"] <= 0).all() - pred_dist = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) - - ml_mode_ind = torch.argmax(log_pis, dim=-1) # n, b, T - # pred_ml = batch_select(mus_xyhv, ml_mode_ind, 3) # n, b, T, 4 - pred_ml = mus_xyhv.permute(1, 3, 0, 2, 4) - - # Dummy single agent prediction. - if isinstance(inputs["scene_batch"], SceneBatch): - agent_fut = inputs["scene_batch"].agent_fut - else: - agent_fut = inputs["scene_batch"]["agent_fut"] - pred_single = torch.full( - [agent_fut.shape[0], 0, agent_fut.shape[2], 4], - dtype=agent_fut.dtype, - device=agent_fut.device, - fill_value=torch.nan, - ) - - output["pred_dist"] = pred_dist - output["pred_dist_with_ego"] = pred_dist_with_ego - output["pred_ml"] = pred_ml - output["pred_single"] = pred_single - output["metrics"] = {} - else: - output["pred_dist"] = None - output["pred_dist_with_ego"] = None - output["pred_ml"] = None - output["pred_single"] = None - output["metrics"] = {} - output["step_time"] = self.cfg["step_time"] - return output - - def compute_losses(self, pred_batch, inputs): - return self.nets["policy"].compute_losses(pred_batch, None) - - def compute_metrics(self, pred_batch, data_batch): - EPS = 1e-3 - metrics = dict() - # calculate GT lane mode and homotopy - batch = pred_batch["data_batch"] - fut_mask = batch["fut_mask"] - mode_valid_flag = fut_mask.all(-1) - B, N, Tf = batch["agent_fut"].shape[:3] - traj = pred_batch["trajectories"].view(B, -1, N, Tf, 3) - if True: - lane_mask = batch["lane_mask"] - fut_xy = batch["agent_fut"][..., :2] - fut_sc = batch["agent_fut"][..., 6:8] - fut_sc = GeoUtils.normalize_sc(fut_sc) - fut_xysc = torch.cat([fut_xy, fut_sc], -1) - - end_points = fut_xysc[:, :, -1] # Only look at final time for GT! - lane_xyh = batch["lane_xyh"] - M = lane_xyh.size(1) - lane_xysc = torch.cat( - [ - lane_xyh[..., :2], - torch.sin(lane_xyh[..., 2:3]), - torch.cos(lane_xyh[..., 2:3]), - ], - -1, - ) - - GT_lane_mode, _ = SimpleLaneRelation.categorize_lane_relation_pts( - end_points.reshape(B * N, 1, 4), - lane_xysc.repeat_interleave(N, 0), - fut_mask.any(-1).reshape(B * N, 1), - lane_mask.repeat_interleave(N, 0), - force_select=False, - ) - # You could have two lanes that it is both on - - GT_lane_mode = GT_lane_mode.squeeze(-2).argmax(-1).reshape(B, N, M) - GT_lane_mode = torch.cat( - [GT_lane_mode, (GT_lane_mode == 0).all(-1, keepdim=True)], -1 - ) - - angle_diff, GT_homotopy = identify_pairwise_homotopy(fut_xy, mask=fut_mask) - GT_homotopy = GT_homotopy.type(torch.int64).reshape(B, N, N) - pred_batch["GT_lane_mode"] = GT_lane_mode - pred_batch["GT_homotopy"] = GT_homotopy - - pred_xysc = torch.cat( - [traj[..., :2], torch.sin(traj[..., 2:3]), torch.cos(traj[..., 2:3])], - -1, - ) - DS = pred_xysc.size(1) - - end_points = pred_xysc[:, :, :, -1] # Only look at final time - - pred_lane_mode, _ = SimpleLaneRelation.categorize_lane_relation_pts( - end_points.reshape(B * N * DS, 1, 4), - lane_xysc.repeat_interleave(N * DS, 0), - fut_mask.any(-1).repeat_interleave(DS, 0).reshape(B * DS * N, 1), - lane_mask.repeat_interleave(DS * N, 0), - force_select=False, - ) - # You could have two lanes that it is both on - - pred_lane_mode = pred_lane_mode.squeeze(-2).argmax(-1).reshape(B, DS, N, M) - pred_lane_mode = torch.cat( - [pred_lane_mode, (pred_lane_mode == 0).all(-1, keepdim=True)], -1 - ) - - angle_diff, pred_homotopy = identify_pairwise_homotopy( - pred_xysc[..., :2].view(B * DS, N, Tf, 2), - mask=fut_mask.repeat_interleave(DS, 0), - ) - pred_homotopy = pred_homotopy.type(torch.int64).reshape(B, DS, N, N) - ML_homotopy_flag = (pred_homotopy[:, 0] == GT_homotopy).all(-1) - - ML_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) - - metrics["ML_homotopy_correct_rate"] = TensorUtils.to_numpy( - (ML_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - all_homotopy_flag = (pred_homotopy == GT_homotopy[:, None]).any(1).all(-1) - all_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) - metrics["all_homotopy_correct_rate"] = TensorUtils.to_numpy( - (all_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - ML_lane_mode_flag = (pred_lane_mode[:, 0] == GT_lane_mode).all(-1) - all_lane_mode_flag = ( - (pred_lane_mode == GT_lane_mode[:, None]).any(1).all(-1) - ) - metrics["ML_lane_mode_correct_rate"] = TensorUtils.to_numpy( - (ML_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - metrics["all_lane_mode_correct_rate"] = TensorUtils.to_numpy( - (all_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - ML_scene_mode_flag = (pred_homotopy[:, 0] == GT_homotopy).all(-1) & ( - pred_lane_mode[:, 0] == GT_lane_mode - ).all(-1) - metrics["ML_scene_mode_correct_rate"] = TensorUtils.to_numpy( - (ML_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - all_scene_mode_flag = ( - (pred_homotopy == GT_homotopy[:, None]).all(-1) - & (pred_lane_mode == GT_lane_mode[:, None]).all(-1) - ).any(1) - metrics["all_scene_mode_correct_rate"] = TensorUtils.to_numpy( - (all_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() - ).item() - - if "GT_homotopy" in pred_batch: - # train/validation mode - - agent_fut, fut_mask, pred_traj = TensorUtils.to_numpy( - (batch["agent_fut"], batch["fut_mask"], pred_batch["trajectories"]) - ) - pred_traj = pred_traj.reshape([B, -1, N, Tf, 3]) - if pred_traj.shape[-2] != agent_fut.shape[-2]: - return metrics - a2a_valid_flag = mode_valid_flag.unsqueeze(-1) * mode_valid_flag.unsqueeze( - -2 - ) - - extent = batch["extent"] - DS = traj.size(1) - pred_edges = self.bu.generate_edges( - batch["agent_type"].repeat_interleave(DS, 0), - extent.repeat_interleave(DS, 0), - TensorUtils.join_dimensions(traj[..., :2], 0, 2), - TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), - batch_first=True, - ) - - coll_loss, dis = collision_loss(pred_edges=pred_edges, return_dis=True) - dis_padded = { - k: v.nan_to_num(1.0).view(B, DS, -1, Tf) for k, v in dis.items() - } - edge_mask = { - k: torch.logical_not(v.isnan().any(-1)).view(B, DS, -1) - for k, v in dis.items() - } - for k in dis_padded: - metrics["collision_rate_ML_" + k] = TensorUtils.to_numpy( - (dis_padded[k][:, 0] < 0).sum() - / (edge_mask[k][:, 0].sum() + EPS) - / Tf - ).item() - metrics[f"collision_rate_all_{DS}_mode_{k}"] = TensorUtils.to_numpy( - (dis_padded[k] < 0).sum() / (edge_mask[k].sum() + EPS) / Tf - ).item() - metrics["collision_rate_ML"] = TensorUtils.to_numpy( - sum([(v[:, 0] < 0).sum() for v in dis_padded.values()]) - / Tf - / (sum([edge_mask[k][:, 0].sum() for k in edge_mask]) + EPS) - ).item() - metrics[f"collision_rate_all_{DS}_mode"] = TensorUtils.to_numpy( - sum([(v < 0).sum() for v in dis_padded.values()]) - / Tf - / (sum([edge_mask[k].sum() for k in edge_mask]) + EPS) - ).item() - - confidence = np.ones([B * N, 1]) - Nmode = pred_traj.shape[1] - agent_type = TensorUtils.to_numpy(batch["agent_type"]) - vehicle_mask = ( - (agent_type == AgentType.VEHICLE) - | (agent_type == AgentType.BICYCLE) - | (agent_type == AgentType.MOTORCYCLE) - ) - pedestrian_mask = agent_type == AgentType.PEDESTRIAN - agent_mask = fut_mask.any(-1) - dt = self.cfg.step_time - for Tsecond in [3.0, 4.0, 5.0, 8.0]: - if Tf < Tsecond / dt: - continue - Tf_bar = int(Tsecond / dt) - - ADE = Metrics.batch_average_displacement_error( - agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), - pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), - confidence, - fut_mask.reshape(B * N, Tf)[..., :Tf_bar], - mode="mean", - ).reshape(B, N) - allADE = ADE.sum() / (agent_mask.sum() + EPS) - vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( - (vehicle_mask * agent_mask).sum() + EPS - ) - pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( - (pedestrian_mask * agent_mask).sum() + EPS - ) - FDE = Metrics.batch_final_displacement_error( - agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), - pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), - confidence, - fut_mask.reshape(B * N, Tf)[..., :Tf_bar], - mode="mean", - ).reshape(B, N) - - allFDE = FDE.sum() / (agent_mask.sum() + EPS) - vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( - (vehicle_mask * agent_mask).sum() + EPS - ) - pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( - (pedestrian_mask * agent_mask).sum() + EPS - ) - metrics[f"ML_ADE@{Tsecond}"] = allADE - metrics[f"ML_FDE@{Tsecond}"] = allFDE - metrics[f"ML_vehicle_ADE@{Tsecond}"] = vehADE - metrics[f"ML_vehicle_FDE@{Tsecond}"] = vehFDE - metrics[f"oracle_pedestrian_ADE@{Tsecond}"] = pedADE - metrics[f"oracle_pedestrian_FDE@{Tsecond}"] = pedFDE - - ADE = Metrics.batch_average_displacement_error( - agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), - pred_traj[..., :Tf_bar, :2] - .transpose(0, 2, 1, 3, 4) - .reshape(B * N, -1, Tf_bar, 2), - confidence.repeat(Nmode, 1) / Nmode, - fut_mask.reshape(B * N, Tf)[..., :Tf_bar], - mode="oracle", - ).reshape(B, N) - FDE = Metrics.batch_final_displacement_error( - agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), - pred_traj[..., :Tf_bar, :2] - .transpose(0, 2, 1, 3, 4) - .reshape(B * N, -1, Tf_bar, 2), - confidence.repeat(Nmode, 1) / Nmode, - fut_mask.reshape(B * N, Tf)[..., :Tf_bar], - mode="oracle", - ).reshape(B, N) - - allADE = ADE.sum() / (agent_mask.sum() + EPS) - vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( - (vehicle_mask * agent_mask).sum() + EPS - ) - pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( - (pedestrian_mask * agent_mask).sum() + EPS - ) - allFDE = FDE.sum() / (agent_mask.sum() + EPS) - vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( - (vehicle_mask * agent_mask).sum() + EPS - ) - pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( - (pedestrian_mask * agent_mask).sum() + EPS - ) - - metrics[f"min_ADE@{Tsecond}s"] = allADE - metrics[f"min_FDE@{Tsecond}s"] = allFDE - metrics[f"min_vehicle_ADE@{Tsecond}s"] = vehADE - metrics[f"min_vehicle_FDE@{Tsecond}s"] = vehFDE - metrics[f"min_pedestrian_ADE@{Tsecond}s"] = pedADE - metrics[f"min_pedestrian_FDE@{Tsecond}s"] = pedFDE - - return metrics From 4756de2d69749d740d9062b49bd424b908b7e733 Mon Sep 17 00:00:00 2001 From: yuxiaoc Date: Fri, 1 Dec 2023 19:37:48 -0800 Subject: [PATCH 10/10] fix typo in readme --- README.md | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c9bad7e..0ffeb6e 100644 --- a/README.md +++ b/README.md @@ -74,19 +74,17 @@ These additional steps might be necessary pip uninstall pathos -y pip install pathos==0.2.9 - # Fix opencv compatibility issue https://github.com/opencv/opencv-python/issues/591 pip uninstall opencv-python opencv-python-headless -y pip install "opencv-python-headless==4.2.0.34" # pip install "opencv-python-headless==4.7.0.72" # for python 3.9 # Sometimes you need to reinstall matplotlib with the correct version - pip install matplotlib==3.3.4 ``` -### Key files and code structure +## Key files and code structure Diffstack uses a similar config system as [TBSIM](https://github.com/NVlabs/traffic-behavior-simulation), where the config templates are first defined in python inside the [diffstack/configs](/diffstack/configs/) folder. We separate the configs for [data](/diffstack/configs/trajdata_config.py), [training](/diffstack/configs/base.py), and [models](/diffstack/configs/algo_config.py). @@ -102,17 +100,11 @@ The main files of CTT to look for is the [model file](/diffstack/models/CTT.py), We also included a rich collection of [utils functions](/diffstack/utils/), among which many are not used by CTT, but we believe they contribute to creating a convenient code base. -### Data +## Data CTT uses [trajdata](https://github.com/NVlabs/trajdata) as the dataloader, technically, you can train with any dataset supported by trajdata. Considering the vectorized map support, we have tested CTT with WOMD, nuScenes, and nuPlan. -## Generating config templates - -``` -python diffstack/scripts/generate_config_templates.py -``` - ## Training and eval The following examples use nuScenes trainval as dataset, you'll need to prepare the nuScenes dataset following instructions in [trajdata](https://github.com/NVlabs/trajdata).